//#include "attention.h"
#include "cascade.h"

#include "dynet/dynet.h"
#include "dynet/training.h"
#include "dynet/lstm.h"
#include "dynet/dict.h"
#include "dynet/timing.h"
#include "dynet/expr.h"
#include "dynet/globals.h"
#include <dynet/io.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <fstream>
#include <iostream>
#include <map>
#include <numeric>
#include <sstream>
#include <tuple>

#include <boost/archive/text_iarchive.hpp>
#include <boost/archive/text_oarchive.hpp>
using namespace std;
using namespace dynet;


// Comparator over example indices that orders longer examples first.
// Comparison key is (source length, interim length, target length),
// each descending; used to group similarly sized examples per batch.
struct TripleLength
{
  TripleLength(const std::vector<std::vector<std::vector<float>>> & v,
               const std::vector<std::vector<int>> & w,
               const std::vector<std::vector<int>> & u) : vec(v), wec(w), uec(u) { }

  // True when example i1 should sort before example i2 (i1 is longer).
  bool operator() (int i1, int i2) {
    if (vec[i1].size() != vec[i2].size()) return vec[i1].size() > vec[i2].size();
    if (wec[i1].size() != wec[i2].size()) return wec[i1].size() > wec[i2].size();
    return uec[i1].size() > uec[i2].size();
  }

  const std::vector<std::vector<std::vector<float>>> & vec;  // speech features
  const std::vector<std::vector<int>> & wec;                 // interim sentences
  const std::vector<std::vector<int>> & uec;                 // target sentences
};

// Pack parallel (speech, interim, target) examples into minibatches of
// at most `max_size` examples each.  When batching (max_size > 1) the
// examples are first sorted longest-first -- by source, then interim,
// then target length -- so each batch holds similarly sized sequences.
// Output batch ids are simply 0..num_batches-1.
inline void CreateMinibatches(const std::vector<std::vector<std::vector<float>>> & train_src,
                              const std::vector<std::vector<int>> & train_int,
                              const std::vector<std::vector<int>> & train_trg,
                              size_t max_size,
                              std::vector<std::vector<std::vector<std::vector<float>>> > & train_src_minibatch,
                              std::vector<std::vector<std::vector<int>> > & train_int_minibatch,
                              std::vector<std::vector<std::vector<int>> > & train_trg_minibatch,
                              std::vector<size_t> & train_ids_minibatch) {
  train_src_minibatch.clear();
  train_int_minibatch.clear();
  train_trg_minibatch.clear();

  // One id per training example.
  std::vector<size_t> ids(train_trg.size());
  std::iota(ids.begin(), ids.end(), 0);

  // Descending (src, int, trg) length order.
  if (max_size > 1) {
    std::sort(ids.begin(), ids.end(), [&](size_t a, size_t b) {
      if (train_src[a].size() != train_src[b].size())
        return train_src[a].size() > train_src[b].size();
      if (train_int[a].size() != train_int[b].size())
        return train_int[a].size() > train_int[b].size();
      return train_trg[a].size() > train_trg[b].size();
    });
  }

  std::vector<std::vector<std::vector<float>>> src_batch;
  std::vector<std::vector<int>> int_batch;
  std::vector<std::vector<int>> trg_batch;

  for (size_t id : ids) {
    src_batch.push_back(train_src[id]);
    int_batch.push_back(train_int[id]);
    trg_batch.push_back(train_trg[id]);

    // Flush once the batch holds max_size examples.
    if (trg_batch.size() >= max_size) {
      train_src_minibatch.push_back(src_batch);
      src_batch.clear();
      train_int_minibatch.push_back(int_batch);
      int_batch.clear();
      train_trg_minibatch.push_back(trg_batch);
      trg_batch.clear();
    }
  }

  // Flush the final partial batch, if any.
  if (!trg_batch.empty()) {
    train_src_minibatch.push_back(src_batch);
    train_int_minibatch.push_back(int_batch);
    train_trg_minibatch.push_back(trg_batch);
  }

  train_ids_minibatch.resize(train_src_minibatch.size());
  std::iota(train_ids_minibatch.begin(), train_ids_minibatch.end(), 0);
}

// ---- Shared network components ----
// File-level globals used by all model classes; the initialize*()
// methods below (re)create them inside a ParameterCollection.
LookupParameter input_lookup;
LSTMBuilder enc_fwd_lstm;
LSTMBuilder enc_bwd_lstm;
LSTMBuilder enc_compress_lstm_1, enc_compress_lstm_2;  // pyramidal encoder layers
LSTMBuilder dec_1_lstm;  // decoder for the interim (ASR) output
LSTMBuilder dec_2_lstm;  // decoder for the target (translation) output

// ---- Optimizer settings ----
// NOTE(review): rate_decay / rate_threshold are not read anywhere in
// this file -- presumably consumed by model code elsewhere; verify.
float learning_rate = 0.0002;
dynet::real learning_scale = 1.0;  // passed through to multitask::train
float rate_decay = 0.8;
float rate_threshold = 1e-5;

float DROPOUT = 0.0;     // LSTM dropout rate (--dropout)
bool USE_REG;            // regularization flag, set by *_reg model types
bool ASR_ONLY = false;   // --asronly flag

// Vocabulary sizes, filled after the dictionaries are frozen.
unsigned int SRC_VOCAB_SIZE = 0;
unsigned int INT_VOCAB_SIZE = 0;
unsigned int TRG_VOCAB_SIZE = 0;

// Data-size / training-loop limits.
int DEV_LIMIT = 10000;
int TEST_LIMIT = 10000;
int TRAIN_LIMIT = 5000000;
int MAX_LEN = 80;
int MAX_EPOCHS = 200;    // --epochs
int BATCH_SIZE = 8;      // --batchsize
int report_every_i;      // minibatches between status/dev reports

int BEAMSIZE = 4;        // --beam, used at decode time

// Unknown-word handling (ids and uniform log-probabilities).
float int_unk_log_prob_, trg_unk_log_prob_;
int INT_UNK_ID, TRG_UNK_ID;

double best = 9e+99;     // NOTE(review): unused in this file -- verify

// Dictionaries: speech "source" markers, interim text, target text.
dynet::Dict src_d;
dynet::Dict int_d;
dynet::Dict trg_d;

int kSOS;  // <s> id
int kEOS;  // </s> id

// Network dimensions (some overridable from the command line).
unsigned ENC_LSTM_NUM_OF_LAYERS;
unsigned DEC_LSTM_NUM_OF_LAYERS;
unsigned REP_SIZE = 512;        // --size
unsigned STATE_SIZE = 512;
unsigned STATE_SIZE1 = 128;     // first (pre-pyramid) encoder layer size
unsigned ATTENTION_SIZE = 512;
unsigned EMBEDDINGS_SIZE = 39;

unsigned NUM_LAYERS = 2;        // --layers
unsigned FEAT_SIZE = 39;        // per-frame PLP feature dimension

float REG_WEIGHT=0.1;
float LENGTH_NORM_WEIGHT=0.8;

// Data locations and filename suffixes (all settable via flags).
string SPEECHDIR, TEXTDIR1, TEXTDIR2;
string TESTSPEECHDIR, TESTTEXTDIR1, TESTTEXTDIR2;
string SPEECHEND = ".plp";
string TEXT1END = ".mb.cleaned.split";
string TEXT2END = ".fr.cleaned.split";

unsigned MAXSOURCE = 2400;      // skip utterances with >= this many frames
bool STACK_FEATS = false;       // --stack: 8-frame stacking, hop 4

// Entry point: parses flags, reads parallel (speech, interim-text,
// target-text) data, then either trains a multitask model or decodes a
// test set / dumps attentions with a previously saved model.
int main(int argc, char **argv) {

  dynet::initialize(argc, argv);

  // Usage text shared by the arg-count check and the unknown-model branch.
  auto print_usage = [&]() {
    cerr << "Usage: " << argv[0] << " model_type train-files.txt dev-files.txt [--speech_dir speechdir] [--text_dir1 textdir1] [--text_dir2 textdir2]" << endl;
    cerr << "\t\t\t [--in model_in] [--partial model_in] [--out model_out] [--test test.txt] [--attention] [--stack] [--input_type 2]" << endl;
    cerr << "\t\t\t [--layers N (def: 2)] [--size K (def: 512] [--epochs L (def: 80)] [--beam M (def: 4)] [--dropout M (def: 0.0)] [--maxsource M (def 20000)]" << endl;
    cerr << "\t\t\t [--speech_dir SPEECHDIR] [--text_dir1 TEXTDIR1] [--text_dir2 TEXTDIR2]" << endl;
    cerr << "\t\t\t [--speech_end SPEECHEND] [--text1_end TEXT1END] [--text2_end TEXT2END]" << endl;
    cerr << "\t\t\t [--test_speech_dir TESTSPEECHDIR] [--test_text_dir1 TESTTEXTDIR1] [--test_text_dir2 TESTTEXTDIR2]" << endl;
  };

  if (argc < 5) {
    print_usage();
    return 1;
  }

  // Select the model architecture from the first positional argument.
  std::vector< multitask* > my_models;
  string modeltype = string(argv[1]);

  if (modeltype == "triangle") {
    my_models.push_back( new triangle());
    USE_REG = false;
  }
  else if (modeltype == "triangle_reg") {
    my_models.push_back( new triangle());
    USE_REG = true;
  }
  else if (modeltype == "cascade") {
    my_models.push_back( new cascade());
    USE_REG = false;
  }
  else if (modeltype == "cascade_reg") {
    my_models.push_back( new cascade());
    USE_REG = true;
  }
  else if (modeltype == "simplemultitask") {
    my_models.push_back( new simplemultitask());
  }
  else if (modeltype == "simpleunitask") {
    my_models.push_back( new simpleunitask());
  }
  else {
    cerr << "model_type needs to be one of [simpleunitask, simplemultitask, cascade, cascade_reg, triangle, triangle_reg]" << endl << endl;
    print_usage();
    return 1;
  }

  bool partial_in = false;
  bool full_in = false;
  bool testmode = false;
  bool attentionmode = false;
  int inputtype = 3;
  string modelin, testfile, modelout, attentionout;

  // Parse the optional flags (positional args occupy argv[1..3]).
  // NOTE(review): value-taking flags read argv[i+1] without a bounds
  // check, so a trailing flag with no value reads past argv -- kept
  // as-is to preserve behavior.
  for (int i = 4; i < argc; i++) {
    if (string(argv[i]) == "--in"){
      full_in = true;
      modelin = argv[i+1];
      if (partial_in){
        cerr << "Can't have both --partial and --in for inputs\n";
        return 1;
      }
    }
    if (string(argv[i]) == "--partial"){
      partial_in = true;
      modelin = argv[i+1];
      if (full_in){
        cerr << "Can't have both --partial and --in for inputs\n";
        return 1;
      }
    }
    if (string(argv[i]) == "--out"){
      modelout = string(argv[i+1]);
    }
    if (string(argv[i]) == "--attention"){
      attentionmode = true;
    }
    if (string(argv[i]) == "--asronly"){
      ASR_ONLY = true;
    }
    if (string(argv[i]) == "--stack"){
      // Frame stacking widens the input, so shrink the other dims.
      STACK_FEATS = true;
      EMBEDDINGS_SIZE = 64;
      STATE_SIZE=256;
    }
    if (string(argv[i]) == "--test"){
      testmode = true;
      testfile = string(argv[i+1]);
    }
    if (string(argv[i]) == "--input_type"){
      inputtype = atoi(argv[i+1]);
      if (inputtype != 2 && inputtype != 3){
        cerr << "input_type has to be either 2 or 3\n";
        return 1;
      }
    }
    if (string(argv[i]) == "--layers")
      NUM_LAYERS = atoi(argv[i+1]);
    if (string(argv[i]) == "--size")
      REP_SIZE = atoi(argv[i+1]);
    if (string(argv[i]) == "--epochs")
      MAX_EPOCHS = atoi(argv[i+1]);
    // BUG FIX: was "--maxsoure", so the documented --maxsource flag was
    // silently ignored.
    if (string(argv[i]) == "--maxsource")
      MAXSOURCE = atoi(argv[i+1]);
    if (string(argv[i]) == "--beam")
      BEAMSIZE = atoi(argv[i+1]);
    if (string(argv[i]) == "--batchsize")
      BATCH_SIZE = atoi(argv[i+1]);
    if (string(argv[i]) == "--dropout")
      DROPOUT = atof(argv[i+1]);
    if (string(argv[i]) == "--speech_dir")
      SPEECHDIR = string(argv[i+1]);
    if (string(argv[i]) == "--speech_end")
      SPEECHEND = string(argv[i+1]);
    if (string(argv[i]) == "--text_dir1")
      TEXTDIR1 = string(argv[i+1]);
    if (string(argv[i]) == "--text1_end")
      TEXT1END = string(argv[i+1]);
    if (string(argv[i]) == "--text_dir2")
      TEXTDIR2 = string(argv[i+1]);
    if (string(argv[i]) == "--text2_end")
      TEXT2END = string(argv[i+1]);
    if (string(argv[i]) == "--test_speech_dir")
      TESTSPEECHDIR = string(argv[i+1]);
    if (string(argv[i]) == "--test_text_dir1")
      TESTTEXTDIR1 = string(argv[i+1]);
    if (string(argv[i]) == "--test_text_dir2")
      TESTTEXTDIR2 = string(argv[i+1]);
  }

  // Flag combination sanity checks.
  if (testmode && partial_in){
    cerr << "Need --in for testmode (--test), not --partial\n";
    return 1;
  }
  if (!testmode && modelout.empty()){
    cerr << "Need to specify a string in order to save the model with --out [model_out] since you are training\n";
    return 1;
  }
  if (inputtype == 2 && modeltype != "simpleunitask"){
    cerr << "input_type 2 only works for a simpleunitask model\n";
    return 1;
  }

  // Register <s>/</s> first in every dictionary so kSOS/kEOS are stable.
  kSOS = src_d.convert("<s>");
  kEOS = src_d.convert("</s>");
  int_d.convert("<s>");
  int_d.convert("</s>");
  trg_d.convert("<s>");
  trg_d.convert("</s>");

  vector<vector<int>> train_int, train_trg, dev_int, dev_trg, test_int, test_trg;
  vector<vector<vector<float>>> train_src, dev_src, test_src;

  string line;
  int tlc = 0;     // training lines read
  int stoks = 0;   // speech frames kept
  int itoks = 0;   // interim tokens kept
  int ttoks = 0;   // target tokens kept
  cerr << "Reading training data from " << argv[2] << "...\n";
  {
    // Each line of the train file is an utterance id; features and the
    // two transcripts live in per-utterance files under the data dirs.
    ifstream in(argv[2]);
    assert(in);
    while(getline(in, line)) {
      vector<vector<float>> x;
      vector<int> y,z;

      // In testmode we only need the dictionaries, so skip the (slow)
      // feature loading; x stays empty and no example is retained below.
      string spfile = SPEECHDIR + '/' + line + SPEECHEND;
      if (!testmode)
        x = read_features(spfile);

      string txtfile1 = TEXTDIR1 + '/' + line + TEXT1END;
      std::ifstream intemp(txtfile1);
      assert(intemp);
      string line1;
      getline(intemp, line1);
      y = read_sentence(line1, int_d);

      // The second (target-language) transcript exists for all but the
      // single-task model.
      if (!(modeltype == "simpleunitask")){
        string txtfile2 = TEXTDIR2 + '/' + line + TEXT2END;
        std::ifstream intemp2(txtfile2);
        // BUG FIX: previously asserted intemp (already checked above)
        // instead of the stream actually opened here.
        assert(intemp2);
        string line2;
        getline(intemp2, line2);
        z = read_sentence(line2, trg_d);
      }

      ++tlc;
      // Keep only non-empty examples shorter than MAXSOURCE frames.
      if (x.size() > 0 && y.size() > 0 && x.size() < MAXSOURCE){
        train_src.push_back(x);
        train_int.push_back(y);
        train_trg.push_back(z);
        stoks += x.size();
        itoks += y.size();
        ttoks += z.size();
      }
    }
    cerr << tlc << " source lines, " << stoks << " PLP features.\n";
    // BUG FIX: the first line reported interim counts as "target types".
    cerr << "\t" << itoks << " interim tokens, " << int_d.size() << " interim types "<< endl;
    cerr << "\t" << ttoks << " target tokens, " << trg_d.size() << " target types "<< endl;
  }

  int_d.freeze(); // no new interim word types allowed
  trg_d.freeze(); // no new target word types allowed
  int_d.set_unk("<unk>");
  trg_d.set_unk("<unk>");
  INT_VOCAB_SIZE = int_d.size();
  TRG_VOCAB_SIZE = trg_d.size();

  // Uniform log-probability assigned to unknown words at decode time.
  INT_UNK_ID = int_d.get_unk_id();
  int_unk_log_prob_ = -log(INT_VOCAB_SIZE);
  TRG_UNK_ID = trg_d.get_unk_id();
  trg_unk_log_prob_ = -log(TRG_VOCAB_SIZE);

  vector<vector< vector< vector<float> > > > train_src_minibatch;
  vector<vector< vector<int> > > train_int_minibatch;
  vector<vector< vector<int> > > train_trg_minibatch;
  vector<size_t> train_ids_minibatch;
  vector<vector<int>> empty_minibatch;
  size_t minibatch_size = BATCH_SIZE;

  int dlc = 0;
  int dstoks = 0;
  int ditoks = 0;
  int dttoks = 0;
  cerr << "Reading dev data from " << argv[3] << "...\n";
  {
    ifstream in(argv[3]);
    assert(in);
    while(getline(in, line)) {
      vector<vector<float>> x;
      vector<int> y,z;

      string spfile = SPEECHDIR + '/' + line + SPEECHEND;
      if (!testmode)
        x = read_features(spfile);

      string txtfile1 = TEXTDIR1 + '/' + line + TEXT1END;
      std::ifstream intemp(txtfile1);
      assert(intemp);
      string line1;
      getline(intemp, line1);
      y = read_sentence(line1, int_d);

      if (!(modeltype == "simpleunitask")){
        string txtfile2 = TEXTDIR2 + '/' + line + TEXT2END;
        std::ifstream intemp2(txtfile2);
        // BUG FIX: same copy-paste assert as in the training loop.
        assert(intemp2);
        string line2;
        getline(intemp2, line2);
        z = read_sentence(line2, trg_d);
      }

      if (x.size() > 0 && y.size() > 0 && x.size() < MAXSOURCE){
        ++dlc;
        dev_src.push_back(x);
        dev_int.push_back(y);
        dev_trg.push_back(z);
        dstoks += x.size();
        ditoks += y.size();
        dttoks += z.size();
      }
    }
    // BUG FIX: missing space before "target tokens" in the report.
    cerr << dlc << " dev lines " << dstoks << " features " << ditoks << " int tokens " << dttoks << " target tokens " << endl;
  }
  std::vector<int> dev_ids(dev_src.size());
  std::iota(dev_ids.begin(), dev_ids.end(), 0);


  // Define model and trainer
  ParameterCollection model;
  AdamTrainer trainer(model, learning_rate);
  trainer.sparse_updates_enabled = false;

  // Build the model; --partial first builds the ASR half, loads its
  // weights, then adds the remaining (translation) parameters.
  if (partial_in){
    my_models[0]->initialize_partial(model);
    cerr << "Reading partial model in from " << modelin << "...\n";
    TextFileLoader l(modelin);
    l.populate(model);
    my_models[0]->initialize_extra(model);
  }
  else
    my_models[0]->initialize(model);

  if (testmode){
    cerr << "Reading in model from " << modelin << "...\n";
    TextFileLoader l(modelin);
    l.populate(model);
  }

  if (testmode){
    int testlc = 0;
    cerr << "Reading test data from " << testfile << "...\n";
    {
      ifstream in(testfile);
      assert(in);
      while(getline(in, line)) {
        vector<vector<float>> x;
        vector<int> y,z;

        string spfile = TESTSPEECHDIR + '/' + line + SPEECHEND;
        x = read_features(spfile);
        test_src.push_back(x);
        ++testlc;

        // Attention dumps also need the reference transcripts.
        if (attentionmode){
          string txtfile1 = TESTTEXTDIR1 + '/' + line + TEXT1END;
          std::ifstream intemp(txtfile1);
          assert(intemp);
          string line1;
          getline(intemp, line1);
          y = read_sentence(line1, int_d);

          if (!(modeltype == "simpleunitask")){
            string txtfile2 = TESTTEXTDIR2 + '/' + line + TEXT2END;
            std::ifstream intemp2(txtfile2);
            // BUG FIX: asserted the wrong stream here as well.
            assert(intemp2);
            string line2;
            getline(intemp2, line2);
            z = read_sentence(line2, trg_d);
          }
          test_int.push_back(y);
          test_trg.push_back(z);
        }
      }
      cerr << testlc << " test lines\n";
    }
  }

  if (!testmode){
    CreateMinibatches(train_src,
                      train_int,
                      train_trg,
                      minibatch_size,
                      train_src_minibatch,
                      train_int_minibatch,
                      train_trg_minibatch,
                      train_ids_minibatch);

    // Create a list with the minibatch ids
    std::vector<int> train_ids(train_src_minibatch.size());
    std::iota(train_ids.begin(), train_ids.end(), 0);

    cerr << " The training set was made into " << train_src_minibatch.size() << " minibatches" << endl;

    // Report once per full pass over the data, capped at 500k batches.
    report_every_i = train_src_minibatch.size();
    if (report_every_i > 500000)
      report_every_i = 500000;

    unsigned si = train_ids.size();  // == size() forces a shuffle first
    bool first = true;
    int report = 0;
    unsigned lines = 0;              // minibatches processed so far
    int epoch = 0;                   // completed passes (bookkeeping only)
    float best_bleu = -1;

    // Training
    while((lines/ (double)train_ids.size()) < MAX_EPOCHS) {
      Timer iteration("completed in");
      double loss = 0;
      unsigned ttokens = 0;
      for (int i = 0; i < report_every_i; ++i) {
        if (si == train_ids.size()) {
          si = 0;
          if (first) { first = false; } else { ++epoch;}
          shuffle(train_ids.begin(), train_ids.end(), *rndeng);
          cerr << "**SHUFFLE\n";
        }
        // Build the graph for this minibatch and take a gradient step.
        int batch_id = train_ids[si++];
        loss += my_models[0]->train(model, train_src_minibatch[batch_id], train_int_minibatch[batch_id], train_trg_minibatch[batch_id], trainer, learning_scale);
        ++lines;
        // Not an exact token count: assumes every sentence in the batch
        // is as long as its first (longest) sentence.
        ttokens += train_int_minibatch[batch_id].size() * train_int_minibatch[batch_id][0].size() + train_trg_minibatch[batch_id].size() * train_trg_minibatch[batch_id][0].size();
      }
      trainer.status();
      // pow(2, x/log(2)) == exp(x): perplexity from mean loss in nats.
      float ppl = pow(2, (loss / ttokens / log(2)));
      cerr << " E = " << (loss / ttokens) << " PPL = " << ppl <<  endl;

      // DEV evaluation
      report++;
      double dbleu = 0;
      for (auto& k : dev_ids) {
        dbleu += my_models[0]->test_dev(model, dev_src[k], dev_int[k], dev_trg[k]);
      }
      float last_bleu = dbleu / dev_ids.size();
      // BUG FIX: this float previously shadowed the int `epoch` counter.
      float cur_epoch = lines / (double)train_ids.size();
      cerr << "\n***DEV [epoch=" << cur_epoch << "] BLEU = " << last_bleu << endl;

      // Save the last checkpoint, duplicate if best.
      // (The redundant ofstream that truncated the file just before
      // TextFileSaver reopened it has been removed.)
      if (last_bleu > best_bleu){
        best_bleu = last_bleu;
        string fname = modelout + ".best.params";
        TextFileSaver s(fname);
        s.save(model);
      }
      string fname = modelout + ".last.params";
      TextFileSaver s(fname);
      s.save(model);
    }
  }

  if (attentionmode) {
    // Disable dropout for deterministic evaluation.
    enc_fwd_lstm.disable_dropout();
    enc_bwd_lstm.disable_dropout();
    enc_compress_lstm_1.disable_dropout();
    enc_compress_lstm_2.disable_dropout();
    dec_1_lstm.disable_dropout();
    dec_2_lstm.disable_dropout();

    // NOTE(review): test_src is only filled in testmode, so --attention
    // without --test iterates over an empty list.
    std::vector<int> test_ids(test_src.size());
    std::iota(test_ids.begin(), test_ids.end(), 0);

    for (auto& k : test_ids) {
      my_models[0]->dump_attentions(model, test_src[k], test_int[k], test_trg[k]);
      cout << endl;
    }
  }
  else if (testmode) {
    // Disable dropout for deterministic decoding.
    enc_fwd_lstm.disable_dropout();
    enc_bwd_lstm.disable_dropout();
    enc_compress_lstm_1.disable_dropout();
    enc_compress_lstm_2.disable_dropout();
    dec_1_lstm.disable_dropout();
    dec_2_lstm.disable_dropout();

    std::vector<int> test_ids(test_src.size());
    std::iota(test_ids.begin(), test_ids.end(), 0);

    for (auto& k : test_ids) {
        cout << k << '\t';
        my_models[0]->test(model, test_src[k], BEAMSIZE);
    }
  }

  // NOTE(review): the model object new'd during selection is never
  // freed; harmless since the process exits here, but deleting it would
  // require multitask to have a virtual destructor -- verify before
  // adding cleanup.
  return 0;
}


/* Definition of general multitask Expressions for all models
 *
 *
 */

// Embed one sentence: look up an embedding for every token, framed by
// <s> at the front and </s> at the end (sentence.size() + 2 outputs).
vector<Expression> multitask::embed_sentence(const vector<int>& sentence, ComputationGraph& cg) {
  vector<Expression> output_exprs;
  output_exprs.reserve(sentence.size() + 2);
  output_exprs.push_back(lookup(cg, input_lookup, kSOS));
  for (int tok : sentence)
    output_exprs.push_back(lookup(cg, input_lookup, tok));
  output_exprs.push_back(lookup(cg, input_lookup, kEOS));
  return output_exprs;
}

// Embed batch version
//
// Embeds a minibatch of sentences into one batched lookup per timestep.
// Assumes the first sentence is the longest (CreateMinibatches sorts
// batches that way).  Produces slen + 2 expressions: an all-<s> step,
// one step per token position of the longest sentence with shorter
// sentences padded by </s>, and a final all-</s> step.
vector<Expression> multitask::embed_sentence(const vector<vector<int>>& sentences, ComputationGraph& cg) {

  // The first sentence of the minibatch should be the longest
  const unsigned slen = sentences[0].size();

  vector<Expression> output_exprs(slen + 2); //character encoding
  
  // Initialize all sentences with SOS
  vector<unsigned> words(sentences.size(), kSOS);

  int index = 0;
  output_exprs.at(index) = lookup(cg, input_lookup, words);

  // The equal in the first for ensures all sentences get a kEOS
  for(int t = 0; t <= slen; t++) {
    // Per-sentence token at position t, or </s> once past its end.
    for(size_t i = 0; i < sentences.size(); i++)
      words[i] = (t < sentences[i].size() ? sentences[i][t] : kEOS);
    output_exprs.at(++index) = lookup(cg, input_lookup, words);
  }

  return output_exprs;
}

// Embed one audio sequence
//
// Turns each frame (a vector of FEAT_SIZE floats) into one graph input
// Expression; one output per input frame.
vector<Expression> multitask::embed_features(const vector<vector<float>>& features, ComputationGraph& cg) {
  vector<Expression> output_exprs(features.size());
  int index = 0;
  // PERF FIX: iterate by const reference -- `auto f` copied every frame
  // vector on each iteration (dynet's input() copies the data anyway).
  for (const auto& f : features) {
    output_exprs.at(index) = input(cg, { FEAT_SIZE }, f);
    index++;
  }
  return output_exprs;
}

// Embed one audio sequence with stacking
//
// Stacks 8 consecutive frames (hop of 4) into one FEAT_SIZE*8 input
// Expression, halving the sequence length twice over.
vector<Expression> multitask::embed_stack_features(const vector<vector<float>>& features, ComputationGraph& cg) {
  // BUG FIX: for fewer than 4 frames, features.size()/4 - 1 went
  // negative and vector<Expression>(N) blew up; short utterances that
  // cannot fill a single 8-frame window now return an empty sequence
  // (which matches the old behavior for 4-7 frames).
  if (features.size() < 8)
    return vector<Expression>();

  int index = 0;
  int N = features.size()/4-1;       // exact number of 8-frame windows
  vector<Expression> output_exprs(N);
  vector<float> feats(FEAT_SIZE * 8);

  for (size_t i = 0; i + 7 < features.size(); i += 4){
    // Flatten frames [i, i+8) into one stacked feature vector.
    for (int j = 0; j < 8; j += 1){
      for (size_t k = 0; k < FEAT_SIZE; k++){
        feats[j*FEAT_SIZE + k] = features[i+j][k];
      }
    }
    output_exprs.at(index++) = input(cg, {FEAT_SIZE*8}, feats);
  }
  return output_exprs;
}


// Embed audio, batch version
//
// Builds one batched input Expression per frame position.  Assumes the
// first utterance is the longest; shorter utterances are zero-padded
// past their end.
vector<Expression> multitask::embed_features(const vector<vector<vector<float>>>& features, ComputationGraph& cg) {

  // The first sentence of the minibatch should be the longest
  const unsigned slen = features[0].size();

  vector<Expression> output_exprs(slen);
  
  // Flattened per-timestep buffer: batch-major, FEAT_SIZE floats each.
  vector<float> feats(features.size()*FEAT_SIZE);

  int index = 0;
  for(int t = 0; t < slen; t++) {
    for(size_t i = 0; i < features.size(); i++){
      for(size_t j = 0; j < FEAT_SIZE; j++){
        // Zero-pad utterances shorter than the longest one.
        feats[i*FEAT_SIZE + j] = (t < features[i].size() ? features[i][t][j] : 0.0);
      }
    }
    output_exprs.at(index++) = input(cg, Dim({FEAT_SIZE},features.size()),feats);
  }

  return output_exprs;
}

// Embed audio with stacked features, batch version
//
// Stacks 8 consecutive frames (hop of 4) per utterance into one batched
// FEAT_SIZE*8 input per window; utterances shorter than the longest one
// are zero-padded.
vector<Expression> multitask::embed_stack_features(const vector<vector<vector<float>>>& features, ComputationGraph& cg) {

  // The first sentence of the minibatch should be the longest
  const unsigned slen = features[0].size();

  // BUG FIX: slen/4 - 1 is unsigned, so for slen < 4 it underflowed to
  // a huge value and vector<Expression>(...) exploded.  Batches whose
  // longest utterance cannot fill an 8-frame window return empty
  // (matching the old behavior for slen in [4,7]).
  if (slen < 8)
    return vector<Expression>();

  vector<Expression> output_exprs(slen/4-1);  // exact window count
  
  // Flattened per-window buffer: batch-major, 8*FEAT_SIZE floats each.
  vector<float> feats(features.size()*FEAT_SIZE*8);

  int index = 0;
  for(int t = 0; t + 7 < slen; t += 4){
    for(int k = 0; k < 8; k++){
      for(size_t i = 0; i < features.size(); i++){
        for(size_t j = 0; j < FEAT_SIZE; j++){
          // Zero-pad past the end of shorter utterances.
          feats[i*FEAT_SIZE*8 + k*FEAT_SIZE + j] = (t + k < features[i].size() ? features[i][t+k][j] : 0.0);
        }
      }
    }
    output_exprs.at(index++) = input(cg, Dim({FEAT_SIZE*8}, features.size()), feats);
  }

  return output_exprs;
}


// Feed each input expression through the LSTM in order and collect the
// hidden state after every step.
vector<Expression> multitask::run_lstm(LSTMBuilder& init_state, const vector<Expression>& input_vecs) {
  vector<Expression> out_vectors;
  out_vectors.reserve(input_vecs.size());
  for (const Expression& x : input_vecs) {
    init_state.add_input(x);
    out_vectors.push_back(init_state.back());
  }
  return out_vectors;
}

// Bi-directional LSTM encoding: run forward and backward passes over
// the embedded sequence and concatenate the aligned states.
vector<Expression> multitask::encode_sentence(LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, vector<Expression>& embedded) {
  vector<Expression> fwd_vectors = run_lstm(enc_fwd_lstm, embedded);

  // Backward pass over the reversed inputs, then flip the outputs so
  // they line up position-by-position with the forward states.
  vector<Expression> embedded_rev(embedded.rbegin(), embedded.rend());
  vector<Expression> bwd_vectors = run_lstm(enc_bwd_lstm, embedded_rev);
  reverse(bwd_vectors.begin(), bwd_vectors.end());

  vector<Expression> encoded;
  encoded.reserve(fwd_vectors.size());
  for (size_t pos = 0; pos < fwd_vectors.size(); ++pos)
    encoded.push_back(concatenate({ fwd_vectors[pos], bwd_vectors[pos] }));
  return encoded;
}

// Pyramidal encoding for speech features: a bi-LSTM layer followed by
// two compression layers, each keeping every second state (plus the
// last one) for a ~4x reduction in sequence length overall.
vector<Expression> multitask::encode_features(LSTMBuilder& enc_fwd_lstm, 
                                               LSTMBuilder& enc_bwd_lstm, 
                                               LSTMBuilder& enc_compress_lstm_1,
                                               LSTMBuilder& enc_compress_lstm_2,
                                               vector<Expression>& embedded) {
  // Bi-directional first layer over the raw frames.
  vector<Expression> fwd_vectors = run_lstm(enc_fwd_lstm, embedded);
  vector<Expression> embedded_rev(embedded.rbegin(), embedded.rend());
  vector<Expression> bwd_vectors = run_lstm(enc_bwd_lstm, embedded_rev);
  reverse(bwd_vectors.begin(), bwd_vectors.end());

  // First 2x compression: concatenate fwd/bwd states but keep only the
  // odd positions, always including the final one.
  vector<Expression> encoded_1;
  const unsigned N = fwd_vectors.size();
  for (unsigned pos = 0; pos < N; ++pos) {
    if (pos % 2 == 1 || pos == N-1)
      encoded_1.push_back(concatenate({ fwd_vectors[pos], bwd_vectors[pos] }));
  }
  vector<Expression> encoded_2 = run_lstm(enc_compress_lstm_1, encoded_1);

  // Second 2x compression with the same keep-odd-plus-last rule.
  vector<Expression> encoded_2_skip;
  const unsigned M = encoded_2.size();
  for (unsigned pos = 0; pos < M; ++pos) {
    if (pos % 2 == 1 || pos == M-1)
      encoded_2_skip.push_back(encoded_2[pos]);
  }

  return run_lstm(enc_compress_lstm_2, encoded_2_skip);
}

// Used for bleu calculations
//
// Collects every n-gram of `sentence` for n = 1..4 together with its
// occurrence count.
map<vector<int>,int> multitask::get_ngrams(const vector<int>& sentence){
  map<vector<int>,int> all_ngrams;
  const int ngram_order_ = 4;
  const int len = (int)sentence.size();
  for (int n = 1; n <= ngram_order_; ++n) {
    for (int start = 0; start + n <= len; ++start) {
      vector<int> ngram(sentence.begin() + start, sentence.begin() + start + n);
      ++all_ngrams[ngram];
    }
  }
  return all_ngrams;
}

// Sentence-level BLEU over n-grams up to order 4 with a brevity
// penalty.  NOTE: uses the arithmetic (not the standard geometric) mean
// of the per-order precisions, as in the original implementation.
float multitask::bleu(const vector<int>& hyp, const vector<int>& ref){
  // BUG FIX: an empty hypothesis previously divided by zero in the
  // brevity penalty (ref_len/hyp_len), producing inf/NaN.
  if (hyp.empty())
    return 0.0f;

  // Get ngrams up to 4 for hyp
  map<vector<int>,int>  hyp_ngrams = get_ngrams(hyp);
  map<vector<int>,int>  ref_ngrams = get_ngrams(ref);

  // vals packs, for each order i (0-based): [matches, hyp n-gram count,
  // ref n-gram count].
  float ref_len = ref.size();
  float hyp_len = hyp.size();
  int ngram_order_ = 4;
  int vals_n = 3*ngram_order_;
  vector<int> vals(vals_n);
  for (int i =0; i<ngram_order_; i++) {
    vals[3*i] = 0;
    vals[3*i+1] = max((int)hyp_len-i,0);
    vals[3*i+2] = max((int)ref_len-i,0);
  }

  // Count clipped n-gram matches between hypothesis and reference.
  for (map<vector<int>,int>::const_iterator it = hyp_ngrams.begin(); it != hyp_ngrams.end(); it++){
    map<vector<int>,int>::const_iterator ref_it = ref_ngrams.find(it->first);
    if(ref_it != ref_ngrams.end()){
      vals[3 * (it->first.size()-1)] += min(ref_it->second,it->second);
    }
  }

  // Average the per-order precisions.
  float tot_prec = 0.0;
  for (int i=0; i < ngram_order_; i++) {
        float num = (vals[3*i]);
        float denom = (vals[3*i+1]);
        float prec = (denom ? num/denom : 0);
        tot_prec += prec;
    }
    tot_prec /= ngram_order_;

  // Brevity penalty: scale by exp(1 - ref/hyp) when the hypothesis is
  // shorter than the reference.
  float bp = 1.0-(float)ref_len/hyp_len;
  if (bp < 0) {
    tot_prec *= exp(bp);
  }
  return tot_prec;
}


/* Definition of triangle model
 *
 *
 */
// Build every parameter of the full triangle model: the pyramidal
// speech encoder, the two attention-equipped decoders (interim text and
// target text), and their output layers.
//
// NOTE(review): registration order appears to matter -- it presumably
// must stay consistent with initialize_partial() + initialize_extra()
// so that TextFileLoader::populate restores the right parameters when a
// partial model is loaded (see main); verify before reordering.
void triangle::initialize(ParameterCollection& model) {

  ENC_LSTM_NUM_OF_LAYERS = NUM_LAYERS;
  DEC_LSTM_NUM_OF_LAYERS = NUM_LAYERS;

  // Stacked features are 8 frames wide, so the first layer's input
  // dimension changes accordingly.
  if (STACK_FEATS){
    enc_fwd_lstm = LSTMBuilder(1, FEAT_SIZE*8, STATE_SIZE1, model);
    enc_bwd_lstm = LSTMBuilder(1, FEAT_SIZE*8, STATE_SIZE1, model);
  }
  else{
    enc_fwd_lstm = LSTMBuilder(1, FEAT_SIZE, STATE_SIZE1, model);
    enc_bwd_lstm = LSTMBuilder(1, FEAT_SIZE, STATE_SIZE1, model);
  }
  
  // Pyramid compression layers: the first consumes the concatenated
  // fwd+bwd states (hence STATE_SIZE1*2).
  enc_compress_lstm_1 = LSTMBuilder(1, STATE_SIZE1*2, STATE_SIZE, model);
  enc_compress_lstm_2 = LSTMBuilder(1, STATE_SIZE, STATE_SIZE, model);
    
  // Decoder 1 (interim) sees [context; embedding]; decoder 2 (target)
  // additionally attends over decoder 1, hence STATE_SIZE * 2.
  dec_1_lstm = LSTMBuilder(DEC_LSTM_NUM_OF_LAYERS, STATE_SIZE + EMBEDDINGS_SIZE , STATE_SIZE, model);
  dec_2_lstm = LSTMBuilder(DEC_LSTM_NUM_OF_LAYERS, STATE_SIZE * 2 + EMBEDDINGS_SIZE , STATE_SIZE, model);
  
  enc_fwd_lstm.set_dropout(DROPOUT);
  enc_bwd_lstm.set_dropout(DROPOUT);
  enc_compress_lstm_1.set_dropout(DROPOUT);
  enc_compress_lstm_2.set_dropout(DROPOUT);
  dec_1_lstm.set_dropout(DROPOUT);
  dec_2_lstm.set_dropout(DROPOUT);
  
  // Output-side embeddings for the two decoders.
  output_1_lookup = model.add_lookup_parameters(INT_VOCAB_SIZE, { EMBEDDINGS_SIZE });
  output_2_lookup = model.add_lookup_parameters(TRG_VOCAB_SIZE, { EMBEDDINGS_SIZE });
  
  // Attention blocks: w1 projects encoder states, w2 projects the full
  // decoder state (all layers, h and c => * 2), v scores the sum.
  attention_1_w1 = model.add_parameters( { ATTENTION_SIZE, STATE_SIZE });
  attention_1_w2 = model.add_parameters( { ATTENTION_SIZE, STATE_SIZE * DEC_LSTM_NUM_OF_LAYERS * 2 });
  attention_1_v = model.add_parameters( { 1, ATTENTION_SIZE });
  
  attention_2_w1 = model.add_parameters( { ATTENTION_SIZE, STATE_SIZE });
  attention_2_w2 = model.add_parameters( { ATTENTION_SIZE, STATE_SIZE * DEC_LSTM_NUM_OF_LAYERS * 2 });
  attention_2_v = model.add_parameters( { 1, ATTENTION_SIZE });
  
  attention_3_w1 = model.add_parameters( { ATTENTION_SIZE, STATE_SIZE  });
  attention_3_w2 = model.add_parameters( { ATTENTION_SIZE, STATE_SIZE * DEC_LSTM_NUM_OF_LAYERS * 2 });
  attention_3_v = model.add_parameters( { 1, ATTENTION_SIZE });
  
  // Softmax output projections for the two decoders.
  decoder_1_w = model.add_parameters( { INT_VOCAB_SIZE, STATE_SIZE });
  decoder_1_b = model.add_parameters( { INT_VOCAB_SIZE });

  decoder_2_w = model.add_parameters( { TRG_VOCAB_SIZE, STATE_SIZE });
  decoder_2_b = model.add_parameters( { TRG_VOCAB_SIZE });
}


// Build only the ASR half of the triangle model (speech encoder +
// interim decoder), so that a pretrained ASR checkpoint can be loaded
// before initialize_extra() adds the translation-side parameters.
//
// NOTE(review): the parameters created here must presumably match the
// prefix of the full initialize() registration order for
// TextFileLoader::populate to load correctly -- verify before editing.
void triangle::initialize_partial(ParameterCollection& model) {
  
  ENC_LSTM_NUM_OF_LAYERS = NUM_LAYERS;
  DEC_LSTM_NUM_OF_LAYERS = NUM_LAYERS;

  // First encoder layer: input width depends on frame stacking.
  if (STACK_FEATS){
    enc_fwd_lstm = LSTMBuilder(1, FEAT_SIZE*8, STATE_SIZE1, model);
    enc_bwd_lstm = LSTMBuilder(1, FEAT_SIZE*8, STATE_SIZE1, model);
  }
  else{
    enc_fwd_lstm = LSTMBuilder(1, FEAT_SIZE, STATE_SIZE1, model);
    enc_bwd_lstm = LSTMBuilder(1, FEAT_SIZE, STATE_SIZE1, model);
  }
  
  enc_compress_lstm_1 = LSTMBuilder(1, STATE_SIZE1*2, STATE_SIZE, model);
  enc_compress_lstm_2 = LSTMBuilder(1, STATE_SIZE, STATE_SIZE, model);
    
  dec_1_lstm = LSTMBuilder(DEC_LSTM_NUM_OF_LAYERS, STATE_SIZE + EMBEDDINGS_SIZE , STATE_SIZE, model);
  
  enc_fwd_lstm.set_dropout(DROPOUT);
  enc_bwd_lstm.set_dropout(DROPOUT);
  enc_compress_lstm_1.set_dropout(DROPOUT);
  enc_compress_lstm_2.set_dropout(DROPOUT);
  dec_1_lstm.set_dropout(DROPOUT);
  
  output_1_lookup = model.add_lookup_parameters(INT_VOCAB_SIZE, { EMBEDDINGS_SIZE });
  
  // Attention 1: encoder states -> interim decoder.
  attention_1_w1 = model.add_parameters( { ATTENTION_SIZE, STATE_SIZE });
  attention_1_w2 = model.add_parameters( { ATTENTION_SIZE, STATE_SIZE * DEC_LSTM_NUM_OF_LAYERS * 2 });
  attention_1_v = model.add_parameters( { 1, ATTENTION_SIZE });
  
  decoder_1_w = model.add_parameters( { INT_VOCAB_SIZE, STATE_SIZE });
  decoder_1_b = model.add_parameters( { INT_VOCAB_SIZE });
}

// Builds the second (translation) decoder on top of an already-initialized
// partial model: its LSTM, target-side embeddings, the two attention heads
// (head 2 attends over encoder states, head 3 over first-decoder states)
// and the output projection to the target vocabulary.
// NOTE: parameters must be registered with `model` in this exact order so
// models serialized earlier keep loading correctly.
void triangle::initialize_extra(ParameterCollection& model) {
  
  ENC_LSTM_NUM_OF_LAYERS = NUM_LAYERS;
  DEC_LSTM_NUM_OF_LAYERS = NUM_LAYERS;

  // Input = encoder context + first-decoder context (STATE_SIZE * 2) + embedding.
  dec_2_lstm = LSTMBuilder(DEC_LSTM_NUM_OF_LAYERS, STATE_SIZE * 2 + EMBEDDINGS_SIZE , STATE_SIZE, model);
  dec_2_lstm.set_dropout(DROPOUT);

  output_2_lookup = model.add_lookup_parameters(TRG_VOCAB_SIZE, { EMBEDDINGS_SIZE });
  
  // Attention 2: over encoder states.
  attention_2_w1 = model.add_parameters( { ATTENTION_SIZE, STATE_SIZE });
  attention_2_w2 = model.add_parameters( { ATTENTION_SIZE, STATE_SIZE * DEC_LSTM_NUM_OF_LAYERS * 2 });
  attention_2_v = model.add_parameters( { 1, ATTENTION_SIZE });
  
  // Attention 3: over first-decoder hidden states.
  attention_3_w1 = model.add_parameters( { ATTENTION_SIZE, STATE_SIZE });
  attention_3_w2 = model.add_parameters( { ATTENTION_SIZE, STATE_SIZE * DEC_LSTM_NUM_OF_LAYERS * 2 });
  attention_3_v = model.add_parameters( { 1, ATTENTION_SIZE });
  
  // Output projection to the target vocabulary.
  decoder_2_w = model.add_parameters( { TRG_VOCAB_SIZE, STATE_SIZE });
  decoder_2_b = model.add_parameters( { TRG_VOCAB_SIZE });
}



// MLP (Bahdanau-style) attention for decoder 1 over the encoder columns.
// `w1dt` is the precomputed (attention_1_w1 * input matrix) term, shared
// across timesteps. Returns a softmax-normalized column of weights.
Expression triangle::attend_1(LSTMBuilder& state, Expression w1dt, ComputationGraph& cg) {
  const Expression w2 = parameter(cg, attention_1_w2);
  const Expression v = parameter(cg, attention_1_v);
  // Project the full decoder state (all layers, h and c) into attention space.
  const Expression w2dt = w2 * concatenate(state.final_s());
  // One unnormalized score per input column.
  const Expression scores = v * tanh(colwise_add(w1dt, w2dt));
  return softmax(transpose(scores));
}

// MLP attention for decoder 2 over the encoder columns (attention head 2).
// `w1dt` is the precomputed (attention_2_w1 * input matrix) term, shared
// across timesteps. Returns a softmax-normalized column of weights.
Expression triangle::attend_2(LSTMBuilder& state, Expression w1dt, ComputationGraph& cg) {
  const Expression w2 = parameter(cg, attention_2_w2);
  const Expression v = parameter(cg, attention_2_v);
  // Project the full decoder state (all layers, h and c) into attention space.
  const Expression w2dt = w2 * concatenate(state.final_s());
  // One unnormalized score per input column.
  const Expression scores = v * tanh(colwise_add(w1dt, w2dt));
  return softmax(transpose(scores));
}

// MLP attention for decoder 2 over the first decoder's hidden states
// (attention head 3). `w1dt` is the precomputed (attention_3_w1 * state
// matrix) term, shared across timesteps.
Expression triangle::attend_3(LSTMBuilder& state, Expression w1dt, ComputationGraph& cg) {
  const Expression w2 = parameter(cg, attention_3_w2);
  const Expression v = parameter(cg, attention_3_v);
  // Project the full decoder state (all layers, h and c) into attention space.
  const Expression w2dt = w2 * concatenate(state.final_s());
  // One unnormalized score per input column.
  const Expression scores = v * tanh(colwise_add(w1dt, w2dt));
  return softmax(transpose(scores));
}


// Attention-consistency regularizer: composing the pivot attentions
// (att_21 * att_32) should approximate the direct attention att_31.
// Penalizes the squared distance between each direct attention column and
// the corresponding column of the composed matrix.
Expression triangle::reg_joint(vector<Expression>& att_21, vector<Expression>& att_32, vector<Expression>& att_31){
  const Expression a21 = concatenate_cols(att_21);
  const Expression a32 = concatenate_cols(att_32);
  // Composite attention via the pivot.
  const Expression a31_pivot = a21 * a32;

  const int n_cols = att_31.size();
  vector<Expression> penalties;
  penalties.reserve(n_cols);
  for (int col = 0; col < n_cols; ++col)
    penalties.push_back(squared_distance(att_31[col], pick(a31_pivot, col, 1)));

  return sum_batches(sum(penalties));
}

// One sentence version.
// Runs both decoders with teacher forcing over one (int, trg) sentence pair
// and returns the summed negative log-likelihood of both outputs, plus the
// optional attention-consistency regularizer.
Expression triangle::decode(LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, vector<Expression>& encoded, const vector<int>& int_sentence, const vector<int>& trg_sentence, ComputationGraph& cg) {
  
  // Gold output sequences, EOS-terminated (copy-construct instead of an
  // element-by-element push loop).
  vector<int> embeddings_1 = int_sentence;
  embeddings_1.push_back(kEOS);
  vector<int> embeddings_2 = trg_sentence;
  embeddings_2.push_back(kEOS);

  // First decoder (speech features -> transcript).
  Expression w = parameter(cg, decoder_1_w);
  Expression b = parameter(cg, decoder_1_b);
  Expression w1 = parameter(cg, attention_1_w1);

  Expression input_mat = concatenate_cols(encoded);

  Expression last_output_embeddings = lookup(cg, output_1_lookup, kSOS);
  // W1 * encoder matrix is timestep-invariant; compute it once.
  Expression w1dt = w1 * input_mat;
  
  // Prime the decoder with the last encoder state and the SOS embedding.
  dec_1_lstm.add_input(concatenate( { encoded.back() , last_output_embeddings }));
  
  vector<Expression> loss;
  vector<Expression> decoder_1_states;  // hidden states, attended by decoder 2
  vector<Expression> att_1;             // kept for the regularizer

  for (int c : embeddings_1) {
    // Attention-weighted encoder context for the current decoder state.
    Expression att_weights_1 = attend_1(dec_1_lstm, w1dt, cg);
    Expression context_1 = input_mat * att_weights_1;

    // Store the attentions for possible regularization
    att_1.push_back(att_weights_1);
    
    // Renamed from `vector` — the original shadowed std::vector.
    Expression dec_input = concatenate( { context_1, last_output_embeddings }); 
    dec_1_lstm.add_input(dec_input);
    Expression out_vector = w * dec_1_lstm.back() + b;
    // NOTE: pickneglogsoftmax would be the numerically-stable equivalent
    // (as used in the batch version); kept as-is to preserve behavior.
    Expression probs = softmax(out_vector);
    // Teacher forcing: feed the gold token at the next step.
    last_output_embeddings = lookup(cg, output_1_lookup, c);
    loss.push_back(-log(pick(probs, c)));   
    // Push the last hidden state for future attention
    decoder_1_states.push_back(dec_1_lstm.back());
  }
  Expression lossint = sum(loss);

  // Second decoder (attends over both encoder and first-decoder states).
  Expression w_2 = parameter(cg, decoder_2_w);
  Expression b_2 = parameter(cg, decoder_2_b);
  Expression w1_2 = parameter(cg, attention_2_w1);
  Expression w1_3 = parameter(cg, attention_3_w1);

  vector<Expression> loss2;

  Expression input_mat_2 = concatenate_cols(decoder_1_states);
  Expression last_output_embeds = lookup(cg, output_2_lookup, kSOS);
  Expression w1dt_2 = w1_2 * input_mat;
  Expression w1dt_3 = w1_3 * input_mat_2;

  dec_2_lstm.add_input( concatenate( { encoded.back(), decoder_1_states.back(), last_output_embeds } ) );

  vector<Expression> att_2;
  vector<Expression> att_3;

  for (int c: embeddings_2){
    // Contexts from both attention heads.
    Expression att_weights_2 = attend_2(dec_2_lstm, w1dt_2, cg);
    Expression context_2 = input_mat * att_weights_2;
    Expression att_weights_3 = attend_3(dec_2_lstm, w1dt_3, cg);
    Expression context_3 = input_mat_2 * att_weights_3;
    
    // Store attentions for possible regularization
    att_2.push_back(att_weights_2);
    att_3.push_back(att_weights_3);
    
    Expression dec_input = concatenate( { context_2, context_3, last_output_embeds });
    dec_2_lstm.add_input(dec_input);
    Expression out_vector = w_2 * dec_2_lstm.back() + b_2;
    Expression probs = softmax(out_vector);
    last_output_embeds = lookup(cg, output_2_lookup, c);
    loss2.push_back(-log(pick(probs, c)));
  }
  Expression losstrg = sum(loss2);

  if (USE_REG){
    // Argument order is (att_21, att_32, att_31) = (att_1, att_3, att_2).
    Expression regularizer = reg_joint(att_1,att_3,att_2);
    return lossint + losstrg + REG_WEIGHT*regularizer;
  }
  return lossint + losstrg;
}


// Version that dumps attentions.
// Runs both decoders with teacher forcing over one (int, trg) sentence pair
// and prints every attention-weight vector to stdout. Unlike decode(), no
// loss is built: the original computed and discarded per-token losses and
// several unused locals (y_values, attention_dump3); that dead work has been
// removed, which does not affect the printed attention values.
void triangle::decode_attentions(LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, vector<Expression>& encoded, const vector<int>& int_sentence, const vector<int>& trg_sentence, ComputationGraph& cg) {
  
  // Gold output sequences, EOS-terminated.
  vector<int> embeddings_1 = int_sentence;
  embeddings_1.push_back(kEOS);
  vector<int> embeddings_2 = trg_sentence;
  embeddings_2.push_back(kEOS);

  // First decoder
  Expression w1 = parameter(cg, attention_1_w1);

  Expression input_mat = concatenate_cols(encoded);

  Expression last_output_embeddings = lookup(cg, output_1_lookup, kSOS);
  Expression w1dt = w1 * input_mat;

  // Prime the decoder with the last encoder state and the SOS embedding.
  dec_1_lstm.add_input(concatenate( { encoded.back(), last_output_embeddings }));
  
  vector<Expression> decoder_1_states;

  cout << "Attention 1" << endl;
  for (int c : embeddings_1) {
    // Attention-weighted encoder context for the current decoder state.
    Expression att_weights_1 = attend_1(dec_1_lstm, w1dt, cg);
    Expression context_1 = input_mat * att_weights_1;

    // .value() forces an incremental forward pass so the weights can print.
    cout << as_vector(att_weights_1.value()) << endl;
    
    Expression dec_input = concatenate( { context_1, last_output_embeddings }); 
    dec_1_lstm.add_input(dec_input);
    // Teacher forcing: feed the gold token at the next step.
    last_output_embeddings = lookup(cg, output_1_lookup, c);
    // Push the last hidden state for future attention
    decoder_1_states.push_back(dec_1_lstm.back());
  }

  // Second decoder
  Expression w1_2 = parameter(cg, attention_2_w1);
  Expression w1_3 = parameter(cg, attention_3_w1);

  Expression input_mat_2 = concatenate_cols(decoder_1_states);
  Expression last_output_embeds = lookup(cg, output_2_lookup, kSOS);
  Expression w1dt_2 = w1_2 * input_mat;
  Expression w1dt_3 = w1_3 * input_mat_2;

  dec_2_lstm.add_input( concatenate( { encoded.back(), decoder_1_states.back(), last_output_embeds } ) );

  cout << endl << "Attention 2 and 3" << endl;
  for (int c: embeddings_2){
    // Contexts from both attention heads (encoder / decoder-1 states).
    Expression att_weights_2 = attend_2(dec_2_lstm, w1dt_2, cg);
    Expression context_2 = input_mat * att_weights_2;
    Expression att_weights_3 = attend_3(dec_2_lstm, w1dt_3, cg);
    Expression context_3 = input_mat_2 * att_weights_3;
    
    cout << as_vector(att_weights_2.value()) << endl;
    cout << as_vector(att_weights_3.value()) << endl;

    Expression dec_input = concatenate( { context_2, context_3, last_output_embeds });
    dec_2_lstm.add_input(dec_input);
    last_output_embeds = lookup(cg, output_2_lookup, c);
  }
  return;
}

// Batch version.
// Teacher-forced decoding over a minibatch; every sentence is padded with
// kEOS up to the longest sentence on its side.
// BUGFIXES vs. the original:
//  * the second decoder looped to the *int* side's max length (max_len_1)
//    instead of the target side's, truncating long target sentences and
//    padding short ones incorrectly;
//  * `words_2` was declared but the first decoder's `words` buffer was
//    reused for the target batch;
//  * max lengths are now computed over the whole batch — CreateMinibatches
//    sorts primarily by source length, so index 0 is not guaranteed to hold
//    the longest int/trg sentence;
//  * unused `embeddings_1`/`embeddings_2` locals removed.
Expression triangle::decode(LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, vector<Expression>& encoded, const vector<vector<int>>& int_batch, const vector<vector<int>>& trg_batch, ComputationGraph& cg) {

  // True maximum lengths on each side of the batch.
  size_t max_len_1 = 0;
  for (const auto& sent : int_batch)
    if (sent.size() > max_len_1) max_len_1 = sent.size();
  size_t max_len_2 = 0;
  for (const auto& sent : trg_batch)
    if (sent.size() > max_len_2) max_len_2 = sent.size();

  // First decoder
  Expression w = parameter(cg, decoder_1_w);
  Expression b = parameter(cg, decoder_1_b);
  Expression w1 = parameter(cg, attention_1_w1);

  // Initialize all sentences with SOS and get them as input
  vector<unsigned> words(int_batch.size(), kSOS);
  Expression input_mat = concatenate_cols(encoded);

  Expression last_output_embeddings = lookup(cg, output_1_lookup, kSOS);
  // W1 * encoder matrix is timestep-invariant; compute it once.
  Expression w1dt = w1 * input_mat;
  
  // Prime the decoder with the last encoder state and the SOS embedding.
  dec_1_lstm.add_input(concatenate( { encoded.back(), last_output_embeddings }));
  
  vector<Expression> loss;
  vector<Expression> decoder_1_states;  // hidden states, attended by decoder 2
  vector<Expression> att_1;             // kept for the regularizer

  // <= so that every sentence emits a terminating kEOS step.
  for(size_t t = 0; t <= max_len_1; t++) {
    for(size_t i = 0; i < int_batch.size(); i++)
      words[i] = (t < int_batch[i].size() ? int_batch[i][t] : kEOS);

    Expression att_weights_1 = attend_1(dec_1_lstm, w1dt, cg);
    Expression context_1 = input_mat * att_weights_1;

    att_1.push_back(att_weights_1);
    
    // Renamed from `vector` — the original shadowed std::vector.
    Expression dec_input = concatenate( { context_1, last_output_embeddings }); 
    dec_1_lstm.add_input(dec_input);
    Expression out_vector = w * dec_1_lstm.back() + b;
    // Teacher forcing: feed the gold tokens at the next step.
    last_output_embeddings = lookup(cg, output_1_lookup, words);
    loss.push_back(pickneglogsoftmax(out_vector, words));   //Batch version for loss

    // Push the last hidden state for future attention
    decoder_1_states.push_back(dec_1_lstm.back());
  }

  Expression lossint = sum_batches(sum(loss));

  // Second decoder
  Expression w_2 = parameter(cg, decoder_2_w);
  Expression b_2 = parameter(cg, decoder_2_b);
  Expression w1_2 = parameter(cg, attention_2_w1);
  Expression w1_3 = parameter(cg, attention_3_w1);

  vector<Expression> loss2;

  Expression input_mat_2 = concatenate_cols(decoder_1_states);

  vector<unsigned> words_2(trg_batch.size(), kSOS);

  Expression last_output_embeds = lookup(cg, output_2_lookup, kSOS);
  Expression w1dt_2 = w1_2 * input_mat;
  Expression w1dt_3 = w1_3 * input_mat_2;

  dec_2_lstm.add_input( concatenate( { encoded.back(), decoder_1_states.back(), last_output_embeds } ) );

  vector<Expression> att_2;
  vector<Expression> att_3;

  // Iterate over the *target* side's max length (was max_len_1).
  for(size_t t = 0; t <= max_len_2; t++) {
    for(size_t i = 0; i < trg_batch.size(); i++)
      words_2[i] = (t < trg_batch[i].size() ? trg_batch[i][t] : kEOS);

    //concatenate input weighted by attention and decoder lstm state
    Expression att_weights_2 = attend_2(dec_2_lstm, w1dt_2, cg);
    Expression context_2 = input_mat * att_weights_2;
    Expression att_weights_3 = attend_3(dec_2_lstm, w1dt_3, cg);
    Expression context_3 = input_mat_2 * att_weights_3;
    
    att_2.push_back(att_weights_2);
    att_3.push_back(att_weights_3);
    
    Expression dec_input = concatenate( { context_2, context_3, last_output_embeds });
    dec_2_lstm.add_input(dec_input);
    Expression out_vector = w_2 * dec_2_lstm.back() + b_2;

    last_output_embeds = lookup(cg, output_2_lookup, words_2);
    loss2.push_back(pickneglogsoftmax(out_vector, words_2));   //Batch version for loss

  }
  Expression losstrg = sum_batches(sum(loss2));

  if (USE_REG){
    // Argument order is (att_21, att_32, att_31) = (att_1, att_3, att_2).
    Expression regularizer = reg_joint(att_1,att_3,att_2);
    return lossint + losstrg + REG_WEIGHT*regularizer;
  }
  return lossint + losstrg;
}


// Single-sentence loss: embeds the source features, encodes them, and
// returns the teacher-forced decoding loss for the gold (int, trg) pair.
Expression triangle::get_loss(const vector<vector<float>>&  src_sentence, const vector<int>&  int_sentence, const vector<int>&  trg_sentence, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, 
    LSTMBuilder& enc_compress_lstm_1, LSTMBuilder& enc_compress_lstm_2, LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, ComputationGraph& cg) {

  // Stacked vs. plain frame features share the same downstream pipeline.
  vector<Expression> embedded = STACK_FEATS ? embed_stack_features(src_sentence, cg)
                                            : embed_features(src_sentence, cg);
  vector<Expression> encoded = encode_features(enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, embedded);
  return decode(dec_1_lstm, dec_2_lstm, encoded, int_sentence, trg_sentence, cg);
}
// Batch version: same pipeline as the single-sentence overload, over a
// whole minibatch of feature sequences and gold id sequences.
Expression triangle::get_loss(const vector<vector<vector<float>>>&  input_batch, const vector<vector<int>>&  int_batch, const vector<vector<int>>&  trg_batch, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, 
    LSTMBuilder& enc_compress_lstm_1, LSTMBuilder& enc_compress_lstm_2, LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, ComputationGraph& cg) { 

  // Stacked vs. plain frame features share the same downstream pipeline.
  vector<Expression> embedded = STACK_FEATS ? embed_stack_features(input_batch, cg)
                                            : embed_features(input_batch, cg);
  vector<Expression> encoded = encode_features(enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, embedded);
  return decode(dec_1_lstm, dec_2_lstm, encoded, int_batch, trg_batch, cg);
}

// One training step on a minibatch: builds a fresh computation graph,
// computes the loss, backpropagates, and applies one trainer update.
// Returns the scalar loss value for logging.
float triangle::train(ParameterCollection& model, const vector<vector<vector<float>>>&  source_batch, const vector<vector<int>>& int_batch, const vector<vector<int>>& target_batch, AdamTrainer& trainer, dynet::real l_scale) {
    ComputationGraph cg;
    // Attach every recurrent component to the new graph and reset its state.
    for (LSTMBuilder* lstm : { &enc_fwd_lstm, &enc_bwd_lstm,
                               &enc_compress_lstm_1, &enc_compress_lstm_2,
                               &dec_1_lstm, &dec_2_lstm }) {
      lstm->new_graph(cg);
      lstm->start_new_sequence();
    }
    Expression loss = get_loss(source_batch, int_batch, target_batch, enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, dec_1_lstm, dec_2_lstm, cg);
    const float loss_value = as_scalar(cg.forward(loss)); // forward propagation
    cg.backward(loss);   // backward propagation
    trainer.update();    // update network weights
    return loss_value;
}


// Embeds and encodes one utterance, then runs decode_attentions() to print
// all attention-weight vectors for the gold (int, trg) pair to stdout.
void triangle::dump_attentions(ParameterCollection& model, const vector<vector<float>>& source_sentence, const vector<int>& int_sentence, const vector<int>& trg_sentence){
    ComputationGraph cg;
    // Attach every recurrent component to the new graph and reset its state.
    for (LSTMBuilder* lstm : { &enc_fwd_lstm, &enc_bwd_lstm,
                               &enc_compress_lstm_1, &enc_compress_lstm_2,
                               &dec_1_lstm, &dec_2_lstm }) {
      lstm->new_graph(cg);
      lstm->start_new_sequence();
    }

    // Stacked vs. plain frame features share the same downstream pipeline.
    vector<Expression> embedded = STACK_FEATS ? embed_stack_features(source_sentence, cg)
                                              : embed_features(source_sentence, cg);
    vector<Expression> encoded = encode_features(enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, embedded);

    decode_attentions(dec_1_lstm, dec_2_lstm, encoded, int_sentence, trg_sentence, cg);
    return;
}
    
// Dev-set scoring hook; currently just forwards to the BLEU-based scorer.
float triangle::test_dev(ParameterCollection& model, const vector<vector<float>>& source_sentence, const vector<int>& int_sentence, const vector<int>& target_sentence) {
    return test_dev_bleu(model, source_sentence, int_sentence, target_sentence);
}

// Decodes one dev utterance greedily (beam size 1, 1-best on both decoders)
// and returns the sum of the BLEU scores of the transcript and translation
// hypotheses against their references.
float triangle::test_dev_bleu(ParameterCollection& model, const vector<vector<float>>& source_sentence, const vector<int>& int_sentence, const vector<int>& target_sentence) {
  ComputationGraph cg;
  auto hyp = generate_nbest(source_sentence, enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, dec_1_lstm, dec_2_lstm, cg, 1, 1, 1);
  return bleu(std::get<0>(hyp), int_sentence) + bleu(std::get<1>(hyp), target_sentence);
}

// Decodes one utterance with beam search and prints the hypothesis to stdout
// as "<int tokens>  ||| <trg tokens> "; in ASR-only mode the target part is
// omitted (the line still ends with " ||| " exactly as before).
// Improvements: the duplicated branch bodies are merged, and the line is
// built with += instead of repeated `out = out + ...`, which created a fresh
// temporary string per token (quadratic in sentence length).
void triangle::test(ParameterCollection& model, const vector<vector<float>>& source_sentence, int beamsize) {
  ComputationGraph cg;
  // Both generators share the same signature and return (int ids, trg ids).
  auto output = ASR_ONLY
      ? generate_nbest_asr(source_sentence, enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, dec_1_lstm, dec_2_lstm, cg, beamsize, 1, beamsize)
      : generate_nbest(source_sentence, enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, dec_1_lstm, dec_2_lstm, cg, beamsize, 1, beamsize);
  string out;
  for (auto c : std::get<0>(output)) {
    out += int_d.convert(c);
    out += " ";
  }
  out += " ||| ";
  if (!ASR_ONLY) {
    for (auto c : std::get<1>(output)) {
      out += trg_d.convert(c);
      out += " ";
    }
  }
  cout << out << endl;
  return;
}

// Cascaded two-stage beam search used at test time.
//   Stage 1: beam search over the first decoder (transcript), keeping the
//     nbest_1_size best complete hypotheses.
//   Stage 2: for each stage-1 candidate, re-run the first decoder with
//     teacher forcing to collect its hidden states, then beam search the
//     second decoder (translation) over them, keeping nbest_2_size
//     hypotheses.
// Returns the (transcript ids, translation ids) pair whose summed
// hypothesis scores are highest.
tuple<vector<int>, vector<int>> triangle::generate_nbest(const vector<vector<float>>& in_seq, 
                                                         LSTMBuilder& enc_fwd_lstm, 
                                                         LSTMBuilder& enc_bwd_lstm,
                                                         LSTMBuilder& enc_compress_lstm_1,
                                                         LSTMBuilder& enc_compress_lstm_2,
                                                         LSTMBuilder& dec_1_lstm, 
                                                         LSTMBuilder& dec_2_lstm, 
                                                         ComputationGraph& cg, 
                                                         int nbest_1_size, 
                                                         int nbest_2_size, 
                                                         int beamsize) {
  
  // Attach every recurrent component to the graph and reset its state.
  enc_fwd_lstm.new_graph(cg);
  enc_bwd_lstm.new_graph(cg);
  enc_compress_lstm_1.new_graph(cg);
  enc_compress_lstm_2.new_graph(cg);
  dec_1_lstm.new_graph(cg);
  dec_2_lstm.new_graph(cg);
  enc_fwd_lstm.start_new_sequence();
  enc_bwd_lstm.start_new_sequence();
  enc_compress_lstm_1.start_new_sequence();
  enc_compress_lstm_2.start_new_sequence();
  dec_1_lstm.start_new_sequence();
  dec_2_lstm.start_new_sequence();

  // Embed and encode the source feature sequence.
  vector<Expression> embedded;
  if (STACK_FEATS)
    embedded = embed_stack_features(in_seq, cg);
  else
    embedded = embed_features(in_seq, cg);
  vector<Expression> encoded = encode_features(enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, embedded);
  
  // The n-best hypotheses for the first decoder
  // NOTE(review): nbest_2 is declared here and never cleared inside the
  // per-candidate loop below, so stage-2 hypotheses from earlier stage-1
  // candidates accumulate across iterations — confirm this is intended.
  vector<DecoderHypPtr> nbest_1, nbest_2;

  // These will hold the best combination from the int and trg decoder
  float best_overall_score = -4000000;
  vector<int> best_int, best_trg;

  // First decoder's beam search
  Expression w = parameter(cg, decoder_1_w);
  Expression b = parameter(cg, decoder_1_b);
  Expression w1 = parameter(cg, attention_1_w1);
  Expression input_mat = concatenate_cols(encoded);
  // W1 * encoder matrix is timestep-invariant; computed once and reused.
  Expression w1dt = w1 * input_mat;

  cg.incremental_forward(w1dt);

  // Final decoder states per beam entry, needed to extend a hypothesis.
  vector<vector<Expression>> last_states(beamsize);

  // Prime the decoder with the last encoder state and the SOS embedding.
  Expression last_output_embeddings = lookup(cg, output_1_lookup, kSOS);
  Expression init_vector = concatenate( { encoded.back(), last_output_embeddings});
  dec_1_lstm.add_input(init_vector);
  vector<Expression> last_s = dec_1_lstm.final_s();

  vector<int> init_sent;
  init_sent.push_back(kSOS);
  vector<DecoderHypPtr> curr_beam(1, DecoderHypPtr(new DecoderHyp(0.0, last_s, init_sent)));

  // Output-length cap derived from the input length.
  // NOTE(review): if 2*len > 80 the cap drops back to len, then the 120
  // clamp can never trigger after that — confirm the intended cap logic.
  int size_limit_ = in_seq.size() * 2;
  if (size_limit_ > 80) size_limit_ = in_seq.size();
  if (size_limit_ > 120) size_limit_ = 120;

  // Beam search 1st decoder
  for (int sent_len = 0; sent_len <= size_limit_; sent_len++) {
    // This vector will hold the best (score, hyp index, word id) triples,
    // kept sorted best-first; slot beamsize is scratch space.
    vector<tuple<dynet::real,int,int> > next_beam_id(beamsize+1, tuple<dynet::real,int,int>(-400000,-1,-1));

    // Iterate over the current beams and go one step forward
    for(int hypid = 0; hypid < (int)curr_beam.size(); hypid++) {
      // Get the current hypothesis
      DecoderHypPtr curr_hyp = curr_beam[hypid];
      // Get the current hypothesis sentence
      const vector<int>& sent = curr_hyp->GetSentence();
      // Do not expand a finished beam
      if (sent[sent_len] == kEOS) continue;

      Expression last_output_embeddings = lookup(cg, output_1_lookup, sent[sent_len]);
  
      // Perform the forward step on the decoder (after init with its last state)
      if (sent_len == 0){
        dec_1_lstm.start_new_sequence();
        dec_1_lstm.add_input(init_vector);
      }
      else
        //dec_lstm.start_new_sequence(last_states[hypid]);
        dec_1_lstm.start_new_sequence(curr_hyp->GetStates());
      //concatenate input weighted by attention and decoder lstm state
      Expression att_weights_1 = attend_1(dec_1_lstm, w1dt, cg);
      Expression context_1 = input_mat * att_weights_1;
      Expression in_vector = concatenate( { context_1, last_output_embeddings }); 
      dec_1_lstm.add_input(in_vector);
          
      // And now get the last softmax
      Expression out_vector = log(softmax(w * dec_1_lstm.back() + b));
      // Add length normalization (Wu et al.-style: ((5+len)/6)^alpha).
      float length_norm = pow(5 + sent_len, LENGTH_NORM_WEIGHT)/(pow(6,LENGTH_NORM_WEIGHT));
      vector<float> probs = as_vector(cg.incremental_forward(out_vector / length_norm));
 
    // Add unknown word penalty
      if(INT_UNK_ID >= 0) probs[INT_UNK_ID] += int_unk_log_prob_;
    // Keep the final state for the continuation of the beam
      last_states[hypid] = dec_1_lstm.final_s();
      // Find the best IDs
      for(int wid = 0; wid < (int)probs.size(); wid++) {
        // The new score will be the current score + the softmax score
        dynet::real my_score = curr_hyp->GetScore() + probs[wid];
        // Now go through the beams from bottom to the beginning
        // and only keep the best <beamsize>
        // (insertion sort into the fixed-size best list).
        int bid;
        for(bid = beamsize; bid > 0 && my_score > get<0>(next_beam_id[bid-1]); bid--)
          next_beam_id[bid] = next_beam_id[bid-1];
        next_beam_id[bid] = tuple<dynet::real,int,int>(my_score,hypid,wid);
      }
    }
    // Create the new hypotheses
    vector<DecoderHypPtr> next_beam;
    for(int i = 0; i < beamsize; i++) {
      dynet::real score = get<0>(next_beam_id[i]);
      int hypid = get<1>(next_beam_id[i]);
      int wid = get<2>(next_beam_id[i]);
      // cerr << "Adding " << wid << " @ beam " << i << ": score=" << get<0>(next_beam_id[i]) - curr_beam[hypid]->GetScore() << endl;
      if(hypid == -1) break;
      // Add the last word to the sentence
      vector<int> next_sent = curr_beam[hypid]->GetSentence();
      next_sent.push_back(wid);
      DecoderHypPtr hyp(new DecoderHyp(score, last_states[hypid], next_sent));     
      // If we are done, add it to the nbest list
      if(wid == kEOS || sent_len == size_limit_) 
        nbest_1.push_back(hyp);
      // Add it to the next beams to be expanded
      next_beam.push_back(hyp);
    }

    // Substitute beams with the next ones 
    curr_beam = next_beam;
    // Check if we're done with search
    if(nbest_1.size() != 0) {
      // NOTE(review): this sorts DecoderHypPtr values — if DecoderHypPtr is
      // a plain smart/raw pointer without a custom operator<, this orders by
      // pointer, not by score; confirm DecoderHyp defines the intended
      // comparison.
      sort(nbest_1.begin(), nbest_1.end());
      // trim to top n options
      if(nbest_1.size() > nbest_1_size)
        nbest_1.resize(nbest_1_size);
      // If we have no more beams to expand
      if(nbest_1.size() == nbest_1_size && (next_beam.size() == 0 || (*nbest_1.rbegin())->GetScore() >= next_beam[0]->GetScore()))
        break;
    }
  }

  // Now beam search on the second decoder for each of the candidates from the first decoder
  for (int kk = 0; kk < nbest_1.size(); kk++){
    float sent_1_score = nbest_1[kk]->GetScore();
    // Re-run the first decoder with teacher forcing over this candidate to
    // collect the sequence of hidden states for attention head 3.
    dec_1_lstm.start_new_sequence();
    dec_1_lstm.add_input(init_vector);
    vector<Expression> decoder_1_states;
    for (auto c : nbest_1[kk]->GetSentence()) {
      Expression att_weights_1 = attend_1(dec_1_lstm, w1dt, cg);
      Expression context_1 = input_mat * att_weights_1;
      // NOTE(review): on the first iteration this reuses whatever
      // last_output_embeddings was left holding after the beam search above
      // rather than a fresh kSOS embedding — confirm this is intended.
      Expression in_vector = concatenate( { context_1, last_output_embeddings }); 
      dec_1_lstm.add_input(in_vector);
      decoder_1_states.push_back(dec_1_lstm.back());
      last_output_embeddings = lookup(cg, output_1_lookup, c);
    }
    // Second decoder
    dec_2_lstm.new_graph(cg);
    dec_2_lstm.start_new_sequence();
    Expression w_2 = parameter(cg, decoder_2_w);
    Expression b_2 = parameter(cg, decoder_2_b);
    Expression w1_2 = parameter(cg, attention_2_w1);
    Expression w1_3 = parameter(cg, attention_3_w1);
    Expression input_mat_2 = concatenate_cols(decoder_1_states);
    Expression last_output_embeds = lookup(cg, output_2_lookup, kSOS);
    Expression w1dt_2 = w1_2 * input_mat;
    Expression w1dt_3 = w1_3 * input_mat_2;
    cg.incremental_forward(w1dt_2);
    cg.incremental_forward(w1dt_3);
    vector<vector<Expression>> last_2_states(beamsize);
    // Prime with the last encoder state, last first-decoder state, and SOS.
    Expression init_2_vector = concatenate( { encoded.back(), decoder_1_states.back(), last_output_embeds } );
    dec_2_lstm.add_input( init_2_vector);
    vector<Expression> last_2_s = dec_2_lstm.final_s();
    vector<int> init_sent;
    init_sent.push_back(kSOS);
    vector<DecoderHypPtr> curr_beam(1, DecoderHypPtr(new DecoderHyp(0.0, last_2_s, init_sent))); 
    // Same length cap as stage 1 (see NOTE above about the clamp logic).
    int size_limit_ = in_seq.size() * 2; 
    if (size_limit_ > 80) size_limit_ = in_seq.size();
    if (size_limit_ > 120) size_limit_ = 120;
    
    // Beam search, 2nd decoder (mirrors the stage-1 loop).
    for (int sent_len = 0; sent_len <= size_limit_; sent_len++) {
      // This vector will hold the best IDs
      vector<tuple<dynet::real,int,int> > next_beam_id(beamsize+1, tuple<dynet::real,int,int>(-400000,-1,-1));
      // Iterate over the current beams and go one step forward
      for(int hypid = 0; hypid < (int)curr_beam.size(); hypid++) {
        // Get the current hypothesis
        DecoderHypPtr curr_hyp = curr_beam[hypid];
        // Get the current hypothesis sentence
        const vector<int>& sent = curr_hyp->GetSentence();
        // Do not expand a finished beam
        if (sent[sent_len] == kEOS) continue;

        Expression last_output_embeds = lookup(cg, output_2_lookup, sent[sent_len]);

        // Perform the forward step on the decoder (after init with its last state)
        if (sent_len == 0){
          dec_2_lstm.start_new_sequence();
          dec_2_lstm.add_input(init_2_vector);
        }
        else
          dec_2_lstm.start_new_sequence(curr_hyp->GetStates());

        //concatenate input weighted by attention and decoder lstm state
        Expression att_weights_2 = attend_2(dec_2_lstm, w1dt_2, cg);
        Expression context_2 = input_mat * att_weights_2;
        Expression att_weights_3 = attend_3(dec_2_lstm, w1dt_3, cg);
        Expression context_3 = input_mat_2 * att_weights_3;
        
        Expression in_vector = concatenate( { context_2, context_3, last_output_embeds });

        dec_2_lstm.add_input(in_vector);

        // And now get the last softmax
        Expression out_vector = log(softmax(w_2 * dec_2_lstm.back() + b_2));
        // Add length normalization
        float length_norm = pow(5 + sent_len, LENGTH_NORM_WEIGHT)/(pow(6,LENGTH_NORM_WEIGHT));
        std::vector<float> probs = as_vector(cg.incremental_forward(out_vector / length_norm));

        // Add unknown word penalty
        if(TRG_UNK_ID >= 0) probs[TRG_UNK_ID] += trg_unk_log_prob_;
        // Keep the final state for the continuation of the beam
        last_2_states[hypid] = dec_2_lstm.final_s();
        // Find the best IDs
        for(int wid = 0; wid < (int)probs.size(); wid++) {
          // The new score will be the current score + the softmax score
          dynet::real my_score = curr_hyp->GetScore() + probs[wid];
          // Now go through the beams from bottom to the beginning
          // and only keep the best <beamsize>
          int bid;
          for(bid = beamsize; bid > 0 && my_score > get<0>(next_beam_id[bid-1]); bid--)
            next_beam_id[bid] = next_beam_id[bid-1];
          next_beam_id[bid] = tuple<dynet::real,int,int>(my_score,hypid,wid);
        }
      }
      // Create the new hypotheses
      vector<DecoderHypPtr> next_beam;
      for(int i = 0; i < beamsize; i++) {
        dynet::real score = get<0>(next_beam_id[i]);
        int hypid = get<1>(next_beam_id[i]);
        int wid = get<2>(next_beam_id[i]);
        if(hypid == -1) break;
        // Add the last word to the sentence
        vector<int> next_sent = curr_beam[hypid]->GetSentence();
        next_sent.push_back(wid);
        DecoderHypPtr hyp(new DecoderHyp(score, last_2_states[hypid], next_sent));     
        // If we are done, add it to the nbest list
        if(wid == kEOS || sent_len == size_limit_) 
          nbest_2.push_back(hyp);
        // Add it to the next beams to be expanded
        next_beam.push_back(hyp);
      }

      // Substitute beams with the next ones 
      curr_beam = next_beam;
      // Check if we're done with search
      if(nbest_2.size() != 0) {
        // NOTE(review): same DecoderHypPtr ordering caveat as in stage 1.
        sort(nbest_2.begin(), nbest_2.end());
        // trim to top n options
        if(nbest_2.size() > nbest_2_size)
          nbest_2.resize(nbest_2_size);
        // If we have no more beams to expand
        if (nbest_2.size() == nbest_2_size && (next_beam.size() == 0 || (*nbest_2.rbegin())->GetScore() >= next_beam[0]->GetScore()))
          break;
      }
    } 
    // Combine the stage-1 score with the best stage-2 score and keep the
    // overall best pair.
    if (sent_1_score + nbest_2[0]->GetScore() > best_overall_score){
       best_overall_score = sent_1_score + nbest_2[0]->GetScore();
       best_int = nbest_1[kk]->GetSentence();
       best_trg = nbest_2[0]->GetSentence();
    }
  }
  return make_tuple(best_int, best_trg);
}

// ASR-only n-best inference for the triangle model: runs beam search over
// the *first* decoder only. The second decoder's builder is attached to the
// graph but never used, so the returned target-side sentence is always
// empty (callers needing both outputs should use the full generate_nbest).
//
//   in_seq       - input speech features, one frame per inner vector
//   nbest_1_size - how many finished first-decoder hypotheses to keep
//   nbest_2_size - unused here; kept for signature parity with generate_nbest
//   beamsize     - beam width of the search
//
// Returns (best intermediate sentence, empty target sentence).
tuple<vector<int>, vector<int>> triangle::generate_nbest_asr(const vector<vector<float>>& in_seq, 
                                                         LSTMBuilder& enc_fwd_lstm, 
                                                         LSTMBuilder& enc_bwd_lstm,
                                                         LSTMBuilder& enc_compress_lstm_1,
                                                         LSTMBuilder& enc_compress_lstm_2,
                                                         LSTMBuilder& dec_1_lstm, 
                                                         LSTMBuilder& dec_2_lstm, 
                                                         ComputationGraph& cg, 
                                                         int nbest_1_size, 
                                                         int nbest_2_size, 
                                                         int beamsize) {
  
  // Attach every builder to this computation graph and reset its state.
  enc_fwd_lstm.new_graph(cg);
  enc_bwd_lstm.new_graph(cg);
  enc_compress_lstm_1.new_graph(cg);
  enc_compress_lstm_2.new_graph(cg);
  dec_1_lstm.new_graph(cg);
  dec_2_lstm.new_graph(cg);
  enc_fwd_lstm.start_new_sequence();
  enc_bwd_lstm.start_new_sequence();
  enc_compress_lstm_1.start_new_sequence();
  enc_compress_lstm_2.start_new_sequence();
  dec_1_lstm.start_new_sequence();
  dec_2_lstm.start_new_sequence();

  // Embed and encode the input frames.
  vector<Expression> embedded;
  if (STACK_FEATS)
    embedded = embed_stack_features(in_seq, cg);
  else
    embedded = embed_features(in_seq, cg);
  vector<Expression> encoded = encode_features(enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, embedded);
  
  // The n-best hypotheses for the first decoder
  vector<DecoderHypPtr> nbest_1, nbest_2;

  // These will hold the best sentences from the int and trg decoder
  // (best_trg stays empty: only the first decoder is searched here).
  vector<int> best_int, best_trg;

  // First decoder's beam search
  Expression w = parameter(cg, decoder_1_w);
  Expression b = parameter(cg, decoder_1_b);
  Expression w1 = parameter(cg, attention_1_w1);
  Expression input_mat = concatenate_cols(encoded);
  // w1 * input_mat is shared by every attention call, so compute it once.
  Expression w1dt = w1 * input_mat;

  cg.incremental_forward(w1dt);

  // Per-beam recurrent decoder states, saved after each expansion.
  vector<vector<Expression>> last_states(beamsize);

  // Prime the decoder with [last encoder state ; SOS embedding].
  Expression last_output_embeddings = lookup(cg, output_1_lookup, kSOS);
  Expression init_vector = concatenate( { encoded.back(), last_output_embeddings});
  dec_1_lstm.add_input(init_vector);
  vector<Expression> last_s = dec_1_lstm.final_s();

  vector<int> init_sent;
  init_sent.push_back(kSOS);
  vector<DecoderHypPtr> curr_beam(1, DecoderHypPtr(new DecoderHyp(0.0, last_s, init_sent)));

  // Heuristic cap on the output length: twice the number of input frames,
  // shrunk back to the input length for long inputs, hard-capped at 120.
  int size_limit_ = in_seq.size() * 2;
  if (size_limit_ > 80) size_limit_ = in_seq.size();
  if (size_limit_ > 120) size_limit_ = 120;

  // Beam search 1st decoder
  for (int sent_len = 0; sent_len <= size_limit_; sent_len++) {
    // Best (score, hypothesis index, word id) candidates for the next step;
    // one extra slot simplifies the insertion below, -400000 acts as -inf.
    vector<tuple<dynet::real,int,int> > next_beam_id(beamsize+1, tuple<dynet::real,int,int>(-400000,-1,-1));

    // Iterate over the current beams and go one step forward
    for(int hypid = 0; hypid < (int)curr_beam.size(); hypid++) {
      // Get the current hypothesis
      DecoderHypPtr curr_hyp = curr_beam[hypid];
      // Get the current hypothesis sentence
      const vector<int>& sent = curr_hyp->GetSentence();
      // Do not expand a finished beam
      if (sent[sent_len] == kEOS) continue;

      Expression last_output_embeddings = lookup(cg, output_1_lookup, sent[sent_len]);
  
      // Perform the forward step on the decoder (after init with its last state)
      if (sent_len == 0){
        dec_1_lstm.start_new_sequence();
        dec_1_lstm.add_input(init_vector);
      }
      else
        // Restore this hypothesis' saved recurrent state before expanding.
        dec_1_lstm.start_new_sequence(curr_hyp->GetStates());
      //concatenate input weighted by attention and decoder lstm state
      Expression att_weights_1 = attend_1(dec_1_lstm, w1dt, cg);
      Expression context_1 = input_mat * att_weights_1;
      Expression in_vector = concatenate( { context_1, last_output_embeddings }); 
      dec_1_lstm.add_input(in_vector);
          
      // And now get the last softmax (in log space)
      Expression out_vector = log(softmax(w * dec_1_lstm.back() + b));
      // GNMT-style length normalization: ((5 + len)^a) / (6^a).
      float length_norm = pow(5 + sent_len, LENGTH_NORM_WEIGHT)/(pow(6,LENGTH_NORM_WEIGHT));
      vector<float> probs = as_vector(cg.incremental_forward(out_vector / length_norm));
 
      // Add unknown word penalty
      if(INT_UNK_ID >= 0) probs[INT_UNK_ID] += int_unk_log_prob_;
      // Keep the final state for the continuation of the beam
      last_states[hypid] = dec_1_lstm.final_s();
      // Find the best IDs
      for(int wid = 0; wid < (int)probs.size(); wid++) {
        // The new score will be the current score + the softmax score
        dynet::real my_score = curr_hyp->GetScore() + probs[wid];
        // Insertion sort from the bottom up: keep only the best <beamsize>
        // candidates across all expanded hypotheses.
        int bid;
        for(bid = beamsize; bid > 0 && my_score > get<0>(next_beam_id[bid-1]); bid--)
          next_beam_id[bid] = next_beam_id[bid-1];
        next_beam_id[bid] = tuple<dynet::real,int,int>(my_score,hypid,wid);
      }
    }
    // Create the new hypotheses
    vector<DecoderHypPtr> next_beam;
    for(int i = 0; i < beamsize; i++) {
      dynet::real score = get<0>(next_beam_id[i]);
      int hypid = get<1>(next_beam_id[i]);
      int wid = get<2>(next_beam_id[i]);
      // A sentinel entry means fewer than beamsize candidates were found.
      if(hypid == -1) break;
      // Add the last word to the sentence
      vector<int> next_sent = curr_beam[hypid]->GetSentence();
      next_sent.push_back(wid);
      DecoderHypPtr hyp(new DecoderHyp(score, last_states[hypid], next_sent));     
      // If we are done, add it to the nbest list
      if(wid == kEOS || sent_len == size_limit_) 
        nbest_1.push_back(hyp);
      // Add it to the next beams to be expanded
      next_beam.push_back(hyp);
    }

    // Substitute beams with the next ones 
    curr_beam = next_beam;
    // Check if we're done with search
    if(nbest_1.size() != 0) {
      sort(nbest_1.begin(), nbest_1.end());
      // trim to top n options
      if(nbest_1.size() > nbest_1_size)
        nbest_1.resize(nbest_1_size);
      // Stop early: the n-best list is full and no surviving beam can beat
      // its worst entry.
      if(nbest_1.size() == nbest_1_size && (next_beam.size() == 0 || (*nbest_1.rbegin())->GetScore() >= next_beam[0]->GetScore()))
        break;
    }
  }
  // nbest_1 is sorted, so its first entry is the best ASR hypothesis.
  best_int = nbest_1[0]->GetSentence();
  return make_tuple(best_int, best_trg);
}


/* Definitions for cascade model
 *
 *
 */

// Build the complete cascade model: the speech encoder (BiLSTM followed by
// two compression LSTMs), both attentional decoders, and every lookup /
// attention / output-projection parameter.
void cascade::initialize(ParameterCollection& model) {

  ENC_LSTM_NUM_OF_LAYERS = NUM_LAYERS;
  DEC_LSTM_NUM_OF_LAYERS = NUM_LAYERS;

  // Stacked features concatenate 8 consecutive frames, which multiplies the
  // encoder input dimensionality accordingly.
  const unsigned enc_input_dim = STACK_FEATS ? FEAT_SIZE * 8 : FEAT_SIZE;
  enc_fwd_lstm = LSTMBuilder(1, enc_input_dim, STATE_SIZE1, model);
  enc_bwd_lstm = LSTMBuilder(1, enc_input_dim, STATE_SIZE1, model);

  // Further transform the concatenated bidirectional encoder output.
  enc_compress_lstm_1 = LSTMBuilder(1, STATE_SIZE1 * 2, STATE_SIZE, model);
  enc_compress_lstm_2 = LSTMBuilder(1, STATE_SIZE, STATE_SIZE, model);

  // Each decoder consumes [attention context ; previous output embedding].
  dec_1_lstm = LSTMBuilder(DEC_LSTM_NUM_OF_LAYERS, STATE_SIZE + EMBEDDINGS_SIZE, STATE_SIZE, model);
  dec_2_lstm = LSTMBuilder(DEC_LSTM_NUM_OF_LAYERS, STATE_SIZE + EMBEDDINGS_SIZE, STATE_SIZE, model);

  // Apply the same dropout rate to every recurrent component.
  for (LSTMBuilder* lstm : { &enc_fwd_lstm, &enc_bwd_lstm, &enc_compress_lstm_1,
                             &enc_compress_lstm_2, &dec_1_lstm, &dec_2_lstm })
    lstm->set_dropout(DROPOUT);

  // Output-side embedding tables for the two decoders.
  output_1_lookup = model.add_lookup_parameters(INT_VOCAB_SIZE, { EMBEDDINGS_SIZE });
  output_2_lookup = model.add_lookup_parameters(TRG_VOCAB_SIZE, { EMBEDDINGS_SIZE });

  // MLP attention: w1 projects the attended matrix, w2 projects the decoder
  // state (all layers, h and c), v scores the combined representation.
  attention_1_w1 = model.add_parameters( { ATTENTION_SIZE, STATE_SIZE });
  attention_1_w2 = model.add_parameters( { ATTENTION_SIZE, STATE_SIZE * DEC_LSTM_NUM_OF_LAYERS * 2 });
  attention_1_v = model.add_parameters( { 1, ATTENTION_SIZE });

  attention_3_w1 = model.add_parameters( { ATTENTION_SIZE, STATE_SIZE });
  attention_3_w2 = model.add_parameters( { ATTENTION_SIZE, STATE_SIZE * DEC_LSTM_NUM_OF_LAYERS * 2 });
  attention_3_v = model.add_parameters( { 1, ATTENTION_SIZE });

  // Softmax projections of the two decoders.
  decoder_1_w = model.add_parameters( { INT_VOCAB_SIZE, STATE_SIZE });
  decoder_1_b = model.add_parameters( { INT_VOCAB_SIZE });

  decoder_2_w = model.add_parameters( { TRG_VOCAB_SIZE, STATE_SIZE });
  decoder_2_b = model.add_parameters( { TRG_VOCAB_SIZE });
}

// Build only the encoder and the first (intermediate) decoder — e.g. for
// pre-training that part before adding the second decoder via
// initialize_extra().
void cascade::initialize_partial(ParameterCollection& model) {

  ENC_LSTM_NUM_OF_LAYERS = NUM_LAYERS;
  DEC_LSTM_NUM_OF_LAYERS = NUM_LAYERS;

  // Stacked features concatenate 8 consecutive frames, which multiplies the
  // encoder input dimensionality accordingly.
  const unsigned enc_input_dim = STACK_FEATS ? FEAT_SIZE * 8 : FEAT_SIZE;
  enc_fwd_lstm = LSTMBuilder(1, enc_input_dim, STATE_SIZE1, model);
  enc_bwd_lstm = LSTMBuilder(1, enc_input_dim, STATE_SIZE1, model);

  // Further transform the concatenated bidirectional encoder output.
  enc_compress_lstm_1 = LSTMBuilder(1, STATE_SIZE1 * 2, STATE_SIZE, model);
  enc_compress_lstm_2 = LSTMBuilder(1, STATE_SIZE, STATE_SIZE, model);

  // The decoder consumes [attention context ; previous output embedding].
  dec_1_lstm = LSTMBuilder(DEC_LSTM_NUM_OF_LAYERS, STATE_SIZE + EMBEDDINGS_SIZE, STATE_SIZE, model);

  // Apply the same dropout rate to every recurrent component.
  for (LSTMBuilder* lstm : { &enc_fwd_lstm, &enc_bwd_lstm, &enc_compress_lstm_1,
                             &enc_compress_lstm_2, &dec_1_lstm })
    lstm->set_dropout(DROPOUT);

  output_1_lookup = model.add_lookup_parameters(INT_VOCAB_SIZE, { EMBEDDINGS_SIZE });

  // MLP attention over the encoder states.
  attention_1_w1 = model.add_parameters( { ATTENTION_SIZE, STATE_SIZE });
  attention_1_w2 = model.add_parameters( { ATTENTION_SIZE, STATE_SIZE * DEC_LSTM_NUM_OF_LAYERS * 2 });
  attention_1_v = model.add_parameters( { 1, ATTENTION_SIZE });

  // Softmax projection of the first decoder.
  decoder_1_w = model.add_parameters( { INT_VOCAB_SIZE, STATE_SIZE });
  decoder_1_b = model.add_parameters( { INT_VOCAB_SIZE });
}

// Add only the components of the second (target) decoder, complementing a
// model previously set up with initialize_partial().
void cascade::initialize_extra(ParameterCollection& model) {

  ENC_LSTM_NUM_OF_LAYERS = NUM_LAYERS;
  DEC_LSTM_NUM_OF_LAYERS = NUM_LAYERS;

  // The decoder consumes [attention context ; previous output embedding].
  dec_2_lstm = LSTMBuilder(DEC_LSTM_NUM_OF_LAYERS, STATE_SIZE + EMBEDDINGS_SIZE, STATE_SIZE, model);
  dec_2_lstm.set_dropout(DROPOUT);

  output_2_lookup = model.add_lookup_parameters(TRG_VOCAB_SIZE, { EMBEDDINGS_SIZE });

  // MLP attention over the first decoder's hidden states.
  attention_3_w1 = model.add_parameters( { ATTENTION_SIZE, STATE_SIZE });
  attention_3_w2 = model.add_parameters( { ATTENTION_SIZE, STATE_SIZE * DEC_LSTM_NUM_OF_LAYERS * 2 });
  attention_3_v = model.add_parameters( { 1, ATTENTION_SIZE });

  // Softmax projection of the second decoder.
  decoder_2_w = model.add_parameters( { TRG_VOCAB_SIZE, STATE_SIZE });
  decoder_2_b = model.add_parameters( { TRG_VOCAB_SIZE });
}


// MLP (Bahdanau-style) attention for the first decoder.
// w1dt is the precomputed projection of the attended matrix (one column per
// position); the decoder's full recurrent state is projected, combined with
// it through tanh, scored by v, and normalized with a softmax.
Expression cascade::attend_1(LSTMBuilder& state, Expression w1dt, ComputationGraph& cg) {
  Expression dec_proj = parameter(cg, attention_1_w2);
  Expression score_v = parameter(cg, attention_1_v);
  // Project the decoder's state (all layers) into attention space.
  Expression w2dt = dec_proj * concatenate(state.final_s());
  // Score every attended column, then normalize into weights.
  Expression scores = transpose(score_v * tanh(colwise_add(w1dt, w2dt)));
  return softmax(scores);
}


// MLP (Bahdanau-style) attention for the second decoder, which attends over
// the first decoder's hidden states. Same computation as attend_1 with its
// own parameter set.
Expression cascade::attend_3(LSTMBuilder& state, Expression w1dt, ComputationGraph& cg) {
  Expression dec_proj = parameter(cg, attention_3_w2);
  Expression score_v = parameter(cg, attention_3_v);
  // Project the decoder's state (all layers) into attention space.
  Expression w2dt = dec_proj * concatenate(state.final_s());
  // Score every attended column, then normalize into weights.
  Expression scores = transpose(score_v * tanh(colwise_add(w1dt, w2dt)));
  return softmax(scores);
}

// TODO(aanastas): This is not batched, most likely
Expression cascade::reg_joint(vector<Expression>& att_21, vector<Expression>& att_32, ComputationGraph& cg){
  vector<Expression> reg_loss;
  unsigned int N = att_32.size();

  Expression a21 = concatenate_cols(att_21);
  Expression a32 = concatenate_cols(att_32);
  
  Expression a31_pivot = a21 * a32;
  for (int k=0; k < N; k++){
    // TODO: Is there a better way to create the I (identity) matrix?
    // (Here we create the k-th row each time)
    vector<dynet::real> identity_row(N,0);
      identity_row.at(k) = 1;

    Expression dist = squared_distance(pick(a31_pivot,k,1), input(cg, {N}, identity_row) );
    reg_loss.push_back(dist);      
  }

  return sum_batches(sum(reg_loss));
}

// One sentence version
Expression cascade::decode(LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, vector<Expression>& encoded, const vector<int>& int_sentence, const vector<int>& trg_sentence, ComputationGraph& cg) {
  
  vector<int> embeddings_1;
  vector<int> embeddings_2;

  for (int c : int_sentence) {
    embeddings_1.push_back(c);
  }
  embeddings_1.push_back(kEOS);

  for (int c : trg_sentence) {
    embeddings_2.push_back(c);
  }
  embeddings_2.push_back(kEOS);

  unsigned N = encoded.size();

  // First decoder
  Expression w = parameter(cg, decoder_1_w);
  Expression b = parameter(cg, decoder_1_b);
  Expression w1 = parameter(cg, attention_1_w1);

  Expression input_mat = concatenate_cols(encoded);

  Expression last_output_embeddings = lookup(cg, output_1_lookup, kSOS);
  Expression w1dt = w1 * input_mat;

  vector<dynet::real> x_values(STATE_SIZE );
  
  dec_1_lstm.add_input(concatenate( { encoded.back(), last_output_embeddings }));
  
  vector<Expression> loss;

  vector<Expression> decoder_1_states;

  vector<Expression> att_1;

  for (int c : embeddings_1) {
    //concatenate input weighted by attention and decoder lstm state
    Expression att_weights_1 = attend_1(dec_1_lstm, w1dt, cg);
    Expression context_1 = input_mat * att_weights_1;

    att_1.push_back(pick_range(att_weights_1, 1, N));
    
    Expression vector = concatenate( { context_1, last_output_embeddings }); 
    dec_1_lstm.add_input(vector);
    Expression out_vector = w * dec_1_lstm.back() + b;
    Expression probs = softmax(out_vector);
    last_output_embeddings = lookup(cg, output_1_lookup, c);
    loss.push_back(-log(pick(probs, c)));   
    // Push the last hidden state for future attention
    decoder_1_states.push_back(dec_1_lstm.back());
  }
  Expression lossint = sum(loss);

  // Second decoder
  Expression w_2 = parameter(cg, decoder_2_w);
  Expression b_2 = parameter(cg, decoder_2_b);
  Expression w1_3 = parameter(cg, attention_3_w1);

  vector<Expression> loss2;

  Expression input_mat_2 = concatenate_cols(decoder_1_states);
  Expression last_output_embeds = lookup(cg, output_2_lookup, kSOS);
  Expression w1dt_3 = w1_3 * input_mat_2;

  dec_2_lstm.add_input( concatenate( { decoder_1_states.back(), last_output_embeds } ) );

  vector<Expression> att_3;

  //cout << endl << "Attention 2" << endl;
  for (int c: embeddings_2){
    //concatenate input weighted by attention and decoder lstm state
    Expression att_weights_3 = attend_3(dec_2_lstm, w1dt_3, cg);
    Expression context_3 = input_mat_2 * att_weights_3;
    
    att_3.push_back(att_weights_3);
    
    Expression vector = concatenate( { context_3, last_output_embeds });
    dec_2_lstm.add_input(vector);
    Expression out_vector = w_2 * dec_2_lstm.back() + b_2;
    Expression probs = softmax(out_vector);
    last_output_embeds = lookup(cg, output_2_lookup, c);
    loss2.push_back(-log(pick(probs, c)));
  }
  Expression losstrg = sum(loss2);

  if (USE_REG){
    Expression regularizer = reg_joint(att_1, att_3, cg);
    return lossint + losstrg + REG_WEIGHT*regularizer;
  }
  return lossint + losstrg;
}

// Version that dumps attentions

// Debug utility: runs the same teacher-forced pass as decode() but only to
// print both decoders' attention vectors to stdout; no loss is returned.
// (The original also built per-step loss expressions and loaded the output
// projections, but never evaluated or returned them — that dead code has
// been removed; the printed values are unchanged.)
void cascade::decode_attentions(LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, vector<Expression>& encoded, const vector<int>& int_sentence, const vector<int>& trg_sentence, ComputationGraph& cg) {

  // Gold sequences with EOS appended, mirroring decode().
  vector<int> embeddings_1(int_sentence);
  embeddings_1.push_back(kEOS);
  vector<int> embeddings_2(trg_sentence);
  embeddings_2.push_back(kEOS);

  // ---- First decoder ----
  Expression w1 = parameter(cg, attention_1_w1);
  Expression input_mat = concatenate_cols(encoded);
  Expression last_output_embeddings = lookup(cg, output_1_lookup, kSOS);
  Expression w1dt = w1 * input_mat;

  // Prime the decoder with [last encoder state ; SOS embedding].
  dec_1_lstm.add_input(concatenate( { encoded.back(), last_output_embeddings }));

  vector<Expression> decoder_1_states;

  cout << "Attention 1" << endl;
  for (int c : embeddings_1) {
    Expression att_weights_1 = attend_1(dec_1_lstm, w1dt, cg);
    Expression context_1 = input_mat * att_weights_1;
    // .value() forces an incremental forward pass up to this node.
    cout << as_vector(att_weights_1.value()) << endl;

    Expression in_vector = concatenate( { context_1, last_output_embeddings });
    dec_1_lstm.add_input(in_vector);
    last_output_embeddings = lookup(cg, output_1_lookup, c);
    // Keep the hidden state: the second decoder attends over these.
    decoder_1_states.push_back(dec_1_lstm.back());
  }

  // ---- Second decoder ----
  Expression w1_3 = parameter(cg, attention_3_w1);
  Expression input_mat_2 = concatenate_cols(decoder_1_states);
  Expression last_output_embeds = lookup(cg, output_2_lookup, kSOS);
  Expression w1dt_3 = w1_3 * input_mat_2;

  dec_2_lstm.add_input( concatenate( { decoder_1_states.back(), last_output_embeds } ) );

  cout << endl << "Attention 3" << endl;
  for (int c: embeddings_2){
    Expression att_weights_3 = attend_3(dec_2_lstm, w1dt_3, cg);
    Expression context_3 = input_mat_2 * att_weights_3;
    cout << as_vector(att_weights_3.value()) << endl;

    Expression in_vector = concatenate( { context_3, last_output_embeds });
    dec_2_lstm.add_input(in_vector);
    last_output_embeds = lookup(cg, output_2_lookup, c);
  }
  return;
}


// Batch version
// Batch version.
// Teacher-forced loss of both decoders over a minibatch. Sentences inside a
// batch are sorted by decreasing length (see CreateMinibatches); shorter
// sentences are padded with kEOS at every step past their end.
Expression cascade::decode(LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, vector<Expression>& encoded, const vector<vector<int>>& int_batch, const vector<vector<int>>& trg_batch, ComputationGraph& cg) {

  // The first sentence of each batch is the longest one.
  const unsigned max_len_1 = int_batch[0].size();
  const unsigned max_len_2 = trg_batch[0].size();

  unsigned N = encoded.size();

  // ---- First decoder (input -> intermediate) ----
  Expression w = parameter(cg, decoder_1_w);
  Expression b = parameter(cg, decoder_1_b);
  Expression w1 = parameter(cg, attention_1_w1);

  // Initialize every sentence with SOS.
  vector<unsigned> words(int_batch.size(), kSOS);
  Expression input_mat = concatenate_cols(encoded);

  Expression last_output_embeddings = lookup(cg, output_1_lookup, kSOS);
  // Precompute the input-side attention projection once.
  Expression w1dt = w1 * input_mat;

  // Prime the decoder with [last encoder state ; SOS embedding].
  dec_1_lstm.add_input(concatenate( { encoded.back(), last_output_embeddings }));

  vector<Expression> loss;
  vector<Expression> decoder_1_states;
  vector<Expression> att_1;

  for(unsigned t = 0; t <= max_len_1; t++) {
    // Gold word at step t for every sentence (EOS once past its length).
    for(size_t i = 0; i < int_batch.size(); i++)
      words[i] = (t < int_batch[i].size() ? int_batch[i][t] : kEOS);

    Expression att_weights_1 = attend_1(dec_1_lstm, w1dt, cg);
    Expression context_1 = input_mat * att_weights_1;

    // We are ignoring the attention on the <kSOS> symbol cause it is fed in the other decoders.
    // That way, att_2 * att_1 will be a square matrix.
    att_1.push_back(pick_range(att_weights_1, 1, N));

    Expression in_vector = concatenate( { context_1, last_output_embeddings });
    dec_1_lstm.add_input(in_vector);
    Expression out_vector = w * dec_1_lstm.back() + b;
    last_output_embeddings = lookup(cg, output_1_lookup, words);
    loss.push_back(pickneglogsoftmax(out_vector, words));   // batched loss

    // Keep the hidden states: the second decoder attends over these.
    decoder_1_states.push_back(dec_1_lstm.back());
  }

  Expression lossint = sum_batches(sum(loss));

  // ---- Second decoder (intermediate -> target) ----
  Expression w_2 = parameter(cg, decoder_2_w);
  Expression b_2 = parameter(cg, decoder_2_b);
  Expression w1_3 = parameter(cg, attention_3_w1);

  vector<Expression> loss2;

  Expression input_mat_2 = concatenate_cols(decoder_1_states);

  vector<unsigned> words_2(trg_batch.size(), kSOS);

  Expression last_output_embeds = lookup(cg, output_2_lookup, kSOS);
  Expression w1dt_3 = w1_3 * input_mat_2;

  dec_2_lstm.add_input( concatenate( { decoder_1_states.back(), last_output_embeds } ) );
  vector<Expression> att_3;

  for(unsigned t = 0; t <= max_len_2; t++) {
    // BUGFIX: the original wrote target ids into `words` (sized for
    // int_batch) and never used `words_2`; use the dedicated buffer so
    // first-decoder ids can never leak into the target lookups.
    for(size_t i = 0; i < trg_batch.size(); i++)
      words_2[i] = (t < trg_batch[i].size() ? trg_batch[i][t] : kEOS);

    //concatenate input weighted by attention and decoder lstm state
    Expression att_weights_3 = attend_3(dec_2_lstm, w1dt_3, cg);
    Expression context_3 = input_mat_2 * att_weights_3;

    att_3.push_back(att_weights_3);

    Expression in_vector = concatenate( { context_3, last_output_embeds });
    dec_2_lstm.add_input(in_vector);
    Expression out_vector = w_2 * dec_2_lstm.back() + b_2;

    last_output_embeds = lookup(cg, output_2_lookup, words_2);
    loss2.push_back(pickneglogsoftmax(out_vector, words_2));   // batched loss
  }
  Expression losstrg = sum_batches(sum(loss2));

  if (USE_REG){
    Expression regularizer = reg_joint(att_1, att_3, cg);
    return lossint + losstrg + REG_WEIGHT*regularizer;
  }
  return lossint + losstrg;
}


// Single-sentence pipeline: embed the source features, encode them, then
// run the two-stage decoder and return the joint loss expression.
Expression cascade::get_loss(const vector<vector<float>>&  src_sentence, const vector<int>&  int_sentence, const vector<int>&  trg_sentence, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, 
   LSTMBuilder& enc_compress_lstm_1, LSTMBuilder& enc_compress_lstm_2, LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, ComputationGraph& cg) {

  // Frame stacking only changes the embedding routine.
  vector<Expression> embedded = STACK_FEATS ? embed_stack_features(src_sentence, cg)
                                            : embed_features(src_sentence, cg);
  vector<Expression> encoded = encode_features(enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, embedded);
  return decode(dec_1_lstm, dec_2_lstm, encoded, int_sentence, trg_sentence, cg);
}
// Batch version
// Batch pipeline: embed the minibatch of source features, encode them, then
// run the batched two-stage decoder and return the joint loss expression.
Expression cascade::get_loss(const vector<vector<vector<float>>>&  input_batch, const vector<vector<int>>&  int_batch, const vector<vector<int>>&  trg_batch, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, 
   LSTMBuilder& enc_compress_lstm_1, LSTMBuilder& enc_compress_lstm_2, LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, ComputationGraph& cg) {

  // Frame stacking only changes the embedding routine.
  vector<Expression> embedded = STACK_FEATS ? embed_stack_features(input_batch, cg)
                                            : embed_features(input_batch, cg);
  vector<Expression> encoded = encode_features(enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, embedded);
  return decode(dec_1_lstm, dec_2_lstm, encoded, int_batch, trg_batch, cg);
}

// One training step on a minibatch: build a fresh computation graph,
// compute the joint loss, backpropagate, and update the weights.
// Returns the scalar minibatch loss. NOTE(review): `model` and `l_scale`
// are not used here; they are kept for interface compatibility.
float cascade::train(ParameterCollection& model, const vector<vector<vector<float>>>&  source_batch, const vector<vector<int>>& int_batch, const vector<vector<int>>& target_batch, AdamTrainer& trainer, dynet::real l_scale) {
    ComputationGraph cg;
    // Attach every builder to the fresh graph and reset its recurrent state.
    auto reset = [&cg](LSTMBuilder& lstm) {
      lstm.new_graph(cg);
      lstm.start_new_sequence();
    };
    reset(enc_fwd_lstm);
    reset(enc_bwd_lstm);
    reset(enc_compress_lstm_1);
    reset(enc_compress_lstm_2);
    reset(dec_1_lstm);
    reset(dec_2_lstm);

    Expression loss = get_loss(source_batch, int_batch, target_batch, enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, dec_1_lstm, dec_2_lstm, cg);

    float loss_value = as_scalar(cg.forward(loss)); // forward propagation
    cg.backward(loss);                              // backward propagation
    trainer.update();                               // update network weights
    return loss_value;
}

// Debug helper: rebuild the graph for one sentence triple and print both
// decoders' attention vectors to stdout via decode_attentions().
void cascade::dump_attentions(ParameterCollection& model, const vector<vector<float>>& source_sentence, const vector<int>& int_sentence, const vector<int>& trg_sentence){
    ComputationGraph cg;
    // Attach every builder to the fresh graph and reset its recurrent state.
    auto reset = [&cg](LSTMBuilder& lstm) {
      lstm.new_graph(cg);
      lstm.start_new_sequence();
    };
    reset(enc_fwd_lstm);
    reset(enc_bwd_lstm);
    reset(enc_compress_lstm_1);
    reset(enc_compress_lstm_2);
    reset(dec_1_lstm);
    reset(dec_2_lstm);

    // Frame stacking only changes the embedding routine.
    vector<Expression> embedded = STACK_FEATS ? embed_stack_features(source_sentence, cg)
                                              : embed_features(source_sentence, cg);
    vector<Expression> encoded = encode_features(enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, embedded);

    decode_attentions(dec_1_lstm, dec_2_lstm, encoded, int_sentence, trg_sentence, cg);
}


// Dev-set evaluation hook; currently delegates to the BLEU-based scorer,
// which evaluates both decoder outputs (see test_dev_bleu).
float cascade::test_dev(ParameterCollection& model, const vector<vector<float>>& source_sentence, const vector<int>& int_sentence, const vector<int>& target_sentence) {
    return test_dev_bleu(model, source_sentence, int_sentence, target_sentence);
}

// Decode one dev sentence with beam size 1 (n-best of 1 for both stages)
// and return the sum of the BLEU scores of the intermediate output against
// int_sentence and of the target output against target_sentence.
float cascade::test_dev_bleu(ParameterCollection& model, const vector<vector<float>>& source_sentence, const vector<int>& int_sentence, const vector<int>& target_sentence) {
  ComputationGraph cg;
  auto hyps = generate_nbest(source_sentence, enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, dec_1_lstm, dec_2_lstm, cg, 1, 1, 1);
  const vector<int>& hyp_int = std::get<0>(hyps);
  const vector<int>& hyp_trg = std::get<1>(hyps);
  return bleu(hyp_int, int_sentence) + bleu(hyp_trg, target_sentence);
}

// Decode one test sentence with beam search and print the two outputs to
// stdout as "<intermediate words>  ||| <target words> ".
void cascade::test(ParameterCollection& model, const vector<vector<float>>& source_sentence, int beamsize) {
  ComputationGraph cg;
  auto output =  generate_nbest(source_sentence, enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, dec_1_lstm, dec_2_lstm, cg, beamsize, 1 ,beamsize);
  // Assemble the line in a stream instead of repeated string concatenation.
  ostringstream line;
  for (auto c : std::get<0>(output))
    line << int_d.convert(c) << " ";
  line << " ||| ";
  for (auto c : std::get<1>(output))
    line << trg_d.convert(c) << " ";
  cout << line.str() << endl;
  return;
}


tuple<vector<int>,vector<int>> cascade::generate_nbest(const vector<vector<float>>& in_seq, 
      LSTMBuilder& enc_fwd_lstm, 
      LSTMBuilder& enc_bwd_lstm, 
      LSTMBuilder& enc_compress_lstm_1,
      LSTMBuilder& enc_compress_lstm_2,
      LSTMBuilder& dec_1_lstm, 
      LSTMBuilder& dec_2_lstm, 
      ComputationGraph& cg, 
      int nbest_1_size, 
      int nbest_2_size, 
      int beamsize) {

  enc_fwd_lstm.new_graph(cg);
  enc_bwd_lstm.new_graph(cg);
  enc_compress_lstm_1.new_graph(cg);
  enc_compress_lstm_2.new_graph(cg);
  dec_1_lstm.new_graph(cg);
  dec_2_lstm.new_graph(cg);
  enc_fwd_lstm.start_new_sequence();
  enc_bwd_lstm.start_new_sequence();
  enc_compress_lstm_1.start_new_sequence();
  enc_compress_lstm_2.start_new_sequence();
  dec_1_lstm.start_new_sequence();
  dec_2_lstm.start_new_sequence();

  vector<Expression> embedded;
  if (STACK_FEATS)
    embedded = embed_stack_features(in_seq, cg);
  else
    embedded = embed_features(in_seq, cg);
  vector<Expression> encoded = encode_features(enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, embedded);
  
  // The n-best hypotheses for the first decoder
  vector<DecoderHypPtr> nbest_1, nbest_2;

  // These will hold the best combination from the int and trg decoder
  float best_overall_score = -4000000;
  vector<int> best_int, best_trg;

  // First decoder's beam search
  Expression w = parameter(cg, decoder_1_w);
  Expression b = parameter(cg, decoder_1_b);
  Expression w1 = parameter(cg, attention_1_w1);
  Expression input_mat = concatenate_cols(encoded);
  Expression w1dt = w1 * input_mat;

  cg.incremental_forward(w1dt);

  vector<vector<Expression>> last_states(beamsize);

  Expression last_output_embeddings = lookup(cg, output_1_lookup, kSOS);
  Expression init_vector = concatenate( { encoded.back(), last_output_embeddings});
  dec_1_lstm.add_input(init_vector);
  vector<Expression> last_s = dec_1_lstm.final_s();

  vector<int> init_sent;
  init_sent.push_back(kSOS);
  vector<DecoderHypPtr> curr_beam(1, DecoderHypPtr(new DecoderHyp(0.0, last_s, init_sent)));

  int size_limit_ = in_seq.size() * 2;
  if (size_limit_ > 80) size_limit_ = in_seq.size();

  // Beam search 1st decoder
  for (int sent_len = 0; sent_len <= size_limit_; sent_len++) {
    // This vector will hold the best IDs
    vector<tuple<dynet::real,int,int> > next_beam_id(beamsize+1, tuple<dynet::real,int,int>(-400000,-1,-1));

    // Iterate over the cyrrent beams and go one step forward
    for(int hypid = 0; hypid < (int)curr_beam.size(); hypid++) {
      // Get the current hypothesis
      DecoderHypPtr curr_hyp = curr_beam[hypid];
      // Get the current hypothesis sentence
      const vector<int>& sent = curr_hyp->GetSentence();
      
      // Do not expand a finished beam
      if (sent[sent_len] == kEOS) continue;

      Expression last_output_embeddings = lookup(cg, output_1_lookup, sent[sent_len]);
  
      // Perform the forward step on the decoder (after init with its last state)
      if (sent_len == 0){
        dec_1_lstm.start_new_sequence();
        dec_1_lstm.add_input(init_vector);
      }
      else
        dec_1_lstm.start_new_sequence(curr_hyp->GetStates());

      //concatenate input weighted by attention and decoder lstm state
      Expression att_weights_1 = attend_1(dec_1_lstm, w1dt, cg);
      Expression context_1 = input_mat * att_weights_1;
      Expression in_vector = concatenate( { context_1, last_output_embeddings }); 
      dec_1_lstm.add_input(in_vector);
          
      // And now get the last softmax
      Expression out_vector = log(softmax(w * dec_1_lstm.back() + b));
      // Add length normalization
      float length_norm = pow(5 + sent_len, LENGTH_NORM_WEIGHT)/(pow(6,LENGTH_NORM_WEIGHT));
      vector<float> probs = as_vector(cg.incremental_forward(out_vector / length_norm));


    // Add unknown word penalty
      if(INT_UNK_ID >= 0) probs[INT_UNK_ID] += int_unk_log_prob_;
    // Keep the final state for the continuation of the beam
      last_states[hypid] = dec_1_lstm.final_s();
      // Find the best IDs
      for(int wid = 0; wid < (int)probs.size(); wid++) {
        // The new score will be the current score + the softmax score
        dynet::real my_score = curr_hyp->GetScore() + probs[wid];
        // Now go through the beams from bottom to the beginning
        // and only keep the best <beamsize>
        int bid;
        for(bid = beamsize; bid > 0 && my_score > get<0>(next_beam_id[bid-1]); bid--)
          next_beam_id[bid] = next_beam_id[bid-1];
        next_beam_id[bid] = tuple<dynet::real,int,int>(my_score,hypid,wid);
      }
    }
    // Create the new hypotheses
    vector<DecoderHypPtr> next_beam;
    for(int i = 0; i < beamsize; i++) {
      dynet::real score = get<0>(next_beam_id[i]);
      int hypid = get<1>(next_beam_id[i]);
      int wid = get<2>(next_beam_id[i]);
      // cerr << "Adding " << wid << " @ beam " << i << ": score=" << get<0>(next_beam_id[i]) - curr_beam[hypid]->GetScore() << endl;
      if(hypid == -1) break;
      // Add the last word to the sentence
      vector<int> next_sent = curr_beam[hypid]->GetSentence();
      next_sent.push_back(wid);
      DecoderHypPtr hyp(new DecoderHyp(score, last_states[hypid], next_sent));     
      // If we are done, add it to the nbest list
      if(wid == kEOS || sent_len == size_limit_) 
        nbest_1.push_back(hyp);
      // Add it do the next beams to be expanded
      next_beam.push_back(hyp);
    }

    // Substitute beams with the next ones 
    curr_beam = next_beam;
    // Check if we're done with search
    if(nbest_1.size() != 0) {
      sort(nbest_1.begin(), nbest_1.end());
      // trim to top n options
      if(nbest_1.size() > nbest_1_size)
        nbest_1.resize(nbest_1_size);
      // If we have no more beams to expand
      if(nbest_1.size() == nbest_1_size && (next_beam.size() == 0 || (*nbest_1.rbegin())->GetScore() >= next_beam[0]->GetScore()))
        break;
    }
  }

  for (int kk = 0; kk < nbest_1.size(); kk++){
    float sent_1_score = nbest_1[kk]->GetScore();
    // Get the sequence of states from the first decoder
    dec_1_lstm.start_new_sequence();
    dec_1_lstm.add_input(init_vector);
    vector<Expression> decoder_1_states;
    for (auto c : nbest_1[kk]->GetSentence()) {
      Expression att_weights_1 = attend_1(dec_1_lstm, w1dt, cg);
      Expression context_1 = input_mat * att_weights_1;
      Expression in_vector = concatenate( { context_1, last_output_embeddings }); 
      dec_1_lstm.add_input(in_vector);
      decoder_1_states.push_back(dec_1_lstm.back());
      last_output_embeddings = lookup(cg, output_1_lookup, c);
    }

    // Second decoder
    dec_2_lstm.new_graph(cg);
    dec_2_lstm.start_new_sequence();
    Expression w_2 = parameter(cg, decoder_2_w);
    Expression b_2 = parameter(cg, decoder_2_b);
    Expression w1_3 = parameter(cg, attention_3_w1);
    Expression input_mat_2 = concatenate_cols(decoder_1_states);
    Expression last_output_embeds = lookup(cg, output_2_lookup, kSOS);
    Expression w1dt_3 = w1_3 * input_mat_2;
    cg.incremental_forward(w1dt_3);
    vector<vector<Expression>> last_2_states(beamsize);
    
    Expression init_2_vector = concatenate( { decoder_1_states.back(), last_output_embeds } );
    dec_2_lstm.add_input( init_2_vector);
    vector<Expression> last_2_s = dec_2_lstm.final_s();
    vector<int> init_sent;
    init_sent.push_back(kSOS);
    vector<DecoderHypPtr> curr_beam(1, DecoderHypPtr(new DecoderHyp(0.0, last_2_s, init_sent))); 
    int size_limit_ = in_seq.size() * 2; 
  if (size_limit_ > 80) size_limit_ = in_seq.size();


    for (int sent_len = 0; sent_len <= size_limit_; sent_len++) {
      // This vector will hold the best IDs
      vector<tuple<dynet::real,int,int> > next_beam_id(beamsize+1, tuple<dynet::real,int,int>(-400000,-1,-1));

      // Iterate over the cyrrent beams and go one step forward
      for(int hypid = 0; hypid < (int)curr_beam.size(); hypid++) {
        // Get the current hypothesis
        DecoderHypPtr curr_hyp = curr_beam[hypid];
        // Get the current hypothesis sentence
        const vector<int>& sent = curr_hyp->GetSentence();
        //if (sent_len != 0 && *sent.rbegin() == 0) continue;
        // Do not expand a finished beam
        if (sent[sent_len] == kEOS) continue;

        Expression last_output_embeds = lookup(cg, output_2_lookup, sent[sent_len]);

        // Perform the forward step on the decoder (after init with its last state)
        if (sent_len == 0){
          dec_2_lstm.start_new_sequence();
          dec_2_lstm.add_input(init_2_vector);
        }
        else
          dec_2_lstm.start_new_sequence(curr_hyp->GetStates());

        //concatenate input weighted by attention and decoder lstm state
        Expression att_weights_3 = attend_3(dec_2_lstm, w1dt_3, cg);
        Expression context_3 = input_mat_2 * att_weights_3;
        
        Expression in_vector = concatenate( { context_3, last_output_embeds });

        dec_2_lstm.add_input(in_vector);

        // And now get the last softmax
        Expression out_vector = log(softmax(w_2 * dec_2_lstm.back() + b_2));
        // Add length normalization
        float length_norm = pow(5 + sent_len, LENGTH_NORM_WEIGHT)/(pow(6,LENGTH_NORM_WEIGHT));
        vector<float> probs = as_vector(cg.incremental_forward(out_vector / length_norm));
          

        // Add unknown word penalty
        if(TRG_UNK_ID >= 0) probs[TRG_UNK_ID] += trg_unk_log_prob_;
        // Keep the final state for the continuation of the beam
        last_2_states[hypid] = dec_2_lstm.final_s();
        // Find the best IDs
        for(int wid = 0; wid < (int)probs.size(); wid++) {
          // The new score will be the current score + the softmax score
          dynet::real my_score = curr_hyp->GetScore() + probs[wid];
          // Now go through the beams from bottom to the beginning
          // and only keep the best <beamsize>
          int bid;
          for(bid = beamsize; bid > 0 && my_score > get<0>(next_beam_id[bid-1]); bid--)
            next_beam_id[bid] = next_beam_id[bid-1];
          next_beam_id[bid] = tuple<dynet::real,int,int>(my_score,hypid,wid);
        }
      }
      // Create the new hypotheses
      vector<DecoderHypPtr> next_beam;
      for(int i = 0; i < beamsize; i++) {
        dynet::real score = get<0>(next_beam_id[i]);
        int hypid = get<1>(next_beam_id[i]);
        int wid = get<2>(next_beam_id[i]);
        // cerr << "Adding " << wid << " @ beam " << i << ": score=" << get<0>(next_beam_id[i]) - curr_beam[hypid]->GetScore() << endl;
        if(hypid == -1) break;
        // Add the last word to the sentence
        vector<int> next_sent = curr_beam[hypid]->GetSentence();
        next_sent.push_back(wid);
        DecoderHypPtr hyp(new DecoderHyp(score, last_2_states[hypid], next_sent));     
        // If we are done, add it to the nbest list
        if(wid == kEOS || sent_len == size_limit_) 
          nbest_2.push_back(hyp);
        // Add it do the next beams to be expanded
        next_beam.push_back(hyp);
      }

      // Substitute beams with the next ones 
      curr_beam = next_beam;
      // Check if we're done with search
      if(nbest_2.size() != 0) {
        sort(nbest_2.begin(), nbest_2.end());
        // trim to top n options
        if(nbest_2.size() > nbest_2_size)
          nbest_2.resize(nbest_2_size);
        // If we have no more beams to expand
        if (nbest_2.size() == nbest_2_size && (next_beam.size() == 0 || (*nbest_2.rbegin())->GetScore() >= next_beam[0]->GetScore()))
          break;
          // return nbest_2
      }
    } 

    if (sent_1_score + nbest_2[0]->GetScore() > best_overall_score){
       best_overall_score = sent_1_score + nbest_2[0]->GetScore();
       best_int = nbest_1[kk]->GetSentence();
       best_trg = nbest_2[0]->GetSentence();
    }
  }
  return make_tuple(best_int, best_trg);
}



/* Definitions for simplemultitask model
 *
 *
 */

void simplemultitask::initialize(ParameterCollection& model) {
  // Build every sub-network of the two-task cascade and register its
  // parameters with `model`.  NOTE: parameter registration order is kept
  // identical to the partial/extra initializers so saved models stay loadable.
  ENC_LSTM_NUM_OF_LAYERS = NUM_LAYERS;
  DEC_LSTM_NUM_OF_LAYERS = NUM_LAYERS;

  // When frames are stacked, the encoder input is eight frames wide.
  const auto enc_input_dim = STACK_FEATS ? FEAT_SIZE * 8 : FEAT_SIZE;
  enc_fwd_lstm = LSTMBuilder(1, enc_input_dim, STATE_SIZE1, model);
  enc_bwd_lstm = LSTMBuilder(1, enc_input_dim, STATE_SIZE1, model);

  // Two stacked "compression" LSTMs squeeze the bidirectional encoder output.
  enc_compress_lstm_1 = LSTMBuilder(1, STATE_SIZE1 * 2, STATE_SIZE, model);
  enc_compress_lstm_2 = LSTMBuilder(1, STATE_SIZE, STATE_SIZE, model);

  // One attentional decoder per task; both consume [context; prev embedding].
  dec_1_lstm = LSTMBuilder(DEC_LSTM_NUM_OF_LAYERS, STATE_SIZE + EMBEDDINGS_SIZE, STATE_SIZE, model);
  dec_2_lstm = LSTMBuilder(DEC_LSTM_NUM_OF_LAYERS, STATE_SIZE + EMBEDDINGS_SIZE, STATE_SIZE, model);

  // Same dropout rate everywhere.
  for (LSTMBuilder* lstm : { &enc_fwd_lstm, &enc_bwd_lstm,
                             &enc_compress_lstm_1, &enc_compress_lstm_2,
                             &dec_1_lstm, &dec_2_lstm })
    lstm->set_dropout(DROPOUT);

  // Output embedding tables, one per task vocabulary.
  output_1_lookup = model.add_lookup_parameters(INT_VOCAB_SIZE, { EMBEDDINGS_SIZE });
  output_2_lookup = model.add_lookup_parameters(TRG_VOCAB_SIZE, { EMBEDDINGS_SIZE });

  // MLP attention parameters (w1: encoder side, w2: decoder state, v: scorer).
  attention_1_w1 = model.add_parameters({ ATTENTION_SIZE, STATE_SIZE });
  attention_1_w2 = model.add_parameters({ ATTENTION_SIZE, STATE_SIZE * DEC_LSTM_NUM_OF_LAYERS * 2 });
  attention_1_v = model.add_parameters({ 1, ATTENTION_SIZE });

  attention_2_w1 = model.add_parameters({ ATTENTION_SIZE, STATE_SIZE });
  attention_2_w2 = model.add_parameters({ ATTENTION_SIZE, STATE_SIZE * DEC_LSTM_NUM_OF_LAYERS * 2 });
  attention_2_v = model.add_parameters({ 1, ATTENTION_SIZE });

  // Output softmax projections.
  decoder_1_w = model.add_parameters({ INT_VOCAB_SIZE, STATE_SIZE });
  decoder_1_b = model.add_parameters({ INT_VOCAB_SIZE });

  decoder_2_w = model.add_parameters({ TRG_VOCAB_SIZE, STATE_SIZE });
  decoder_2_b = model.add_parameters({ TRG_VOCAB_SIZE });
}
void simplemultitask::initialize_partial(ParameterCollection& model) {
  // Register only the encoder and the first-task decoder; the second-task
  // components are added separately by initialize_extra().  Registration
  // order matches initialize() for serialization compatibility.
  ENC_LSTM_NUM_OF_LAYERS = NUM_LAYERS;
  DEC_LSTM_NUM_OF_LAYERS = NUM_LAYERS;

  // Stacked frames widen the encoder input eight-fold.
  const auto enc_input_dim = STACK_FEATS ? FEAT_SIZE * 8 : FEAT_SIZE;
  enc_fwd_lstm = LSTMBuilder(1, enc_input_dim, STATE_SIZE1, model);
  enc_bwd_lstm = LSTMBuilder(1, enc_input_dim, STATE_SIZE1, model);

  enc_compress_lstm_1 = LSTMBuilder(1, STATE_SIZE1 * 2, STATE_SIZE, model);
  enc_compress_lstm_2 = LSTMBuilder(1, STATE_SIZE, STATE_SIZE, model);

  // First-task decoder only.
  dec_1_lstm = LSTMBuilder(DEC_LSTM_NUM_OF_LAYERS, STATE_SIZE + EMBEDDINGS_SIZE, STATE_SIZE, model);

  for (LSTMBuilder* lstm : { &enc_fwd_lstm, &enc_bwd_lstm,
                             &enc_compress_lstm_1, &enc_compress_lstm_2,
                             &dec_1_lstm })
    lstm->set_dropout(DROPOUT);

  output_1_lookup = model.add_lookup_parameters(INT_VOCAB_SIZE, { EMBEDDINGS_SIZE });

  attention_1_w1 = model.add_parameters({ ATTENTION_SIZE, STATE_SIZE });
  attention_1_w2 = model.add_parameters({ ATTENTION_SIZE, STATE_SIZE * DEC_LSTM_NUM_OF_LAYERS * 2 });
  attention_1_v = model.add_parameters({ 1, ATTENTION_SIZE });

  decoder_1_w = model.add_parameters({ INT_VOCAB_SIZE, STATE_SIZE });
  decoder_1_b = model.add_parameters({ INT_VOCAB_SIZE });
}

void simplemultitask::initialize_extra(ParameterCollection& model) {
  // Register the second-task components on top of an initialize_partial()
  // model: its decoder LSTM, embedding table, attention, and output softmax.
  ENC_LSTM_NUM_OF_LAYERS = NUM_LAYERS;
  DEC_LSTM_NUM_OF_LAYERS = NUM_LAYERS;

  // Second-task decoder consumes [context; previous embedding].
  dec_2_lstm = LSTMBuilder(DEC_LSTM_NUM_OF_LAYERS, STATE_SIZE + EMBEDDINGS_SIZE, STATE_SIZE, model);
  dec_2_lstm.set_dropout(DROPOUT);

  output_2_lookup = model.add_lookup_parameters(TRG_VOCAB_SIZE, { EMBEDDINGS_SIZE });

  // MLP attention parameters for the second decoder.
  attention_2_w1 = model.add_parameters({ ATTENTION_SIZE, STATE_SIZE });
  attention_2_w2 = model.add_parameters({ ATTENTION_SIZE, STATE_SIZE * DEC_LSTM_NUM_OF_LAYERS * 2 });
  attention_2_v = model.add_parameters({ 1, ATTENTION_SIZE });

  // Output softmax projection for the target vocabulary.
  decoder_2_w = model.add_parameters({ TRG_VOCAB_SIZE, STATE_SIZE });
  decoder_2_b = model.add_parameters({ TRG_VOCAB_SIZE });
}


Expression simplemultitask::attend_1(LSTMBuilder& state, Expression w1dt, ComputationGraph& cg) {
  // MLP (Bahdanau-style) attention for decoder 1:
  //   score_j = v . tanh(W1*h_j + W2*s)
  // where w1dt = W1*H was precomputed once per sentence by the caller and
  // s is the concatenation of the decoder LSTM's final states.
  const Expression w2dt = parameter(cg, attention_1_w2) * concatenate(state.final_s());
  const Expression scores = transpose(parameter(cg, attention_1_v) * tanh(colwise_add(w1dt, w2dt)));
  // Normalize the scores into a distribution over encoder positions.
  return softmax(scores);
}


Expression simplemultitask::attend_2(LSTMBuilder& state, Expression w1dt, ComputationGraph& cg) {
  // MLP (Bahdanau-style) attention for decoder 2; identical form to
  // attend_1 but with the second attention head's parameters.
  const Expression w2dt = parameter(cg, attention_2_w2) * concatenate(state.final_s());
  const Expression scores = transpose(parameter(cg, attention_2_v) * tanh(colwise_add(w1dt, w2dt)));
  // Normalize the scores into a distribution over encoder positions.
  return softmax(scores);
}


// One sentence version
Expression simplemultitask::decode(LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, vector<Expression>& encoded, const vector<int>& int_sentence, const vector<int>& trg_sentence, ComputationGraph& cg) {
  // Teacher-forced cascade decoding for a single training triple.
  // Both decoders attend over the encoder outputs (`encoded`).
  // Returns the summed negative log-likelihood of the int and trg sequences.
  //
  // Cleanup vs. previous version: the `att_1`/`att_2`/`decoder_1_states`
  // vectors were collected but never read, and a local was named `vector`
  // (shadowing std::vector); both removed/renamed.

  // Reference sequences with EOS appended.
  vector<int> embeddings_1 = int_sentence;
  embeddings_1.push_back(kEOS);
  vector<int> embeddings_2 = trg_sentence;
  embeddings_2.push_back(kEOS);

  // First decoder
  Expression w = parameter(cg, decoder_1_w);
  Expression b = parameter(cg, decoder_1_b);
  Expression w1 = parameter(cg, attention_1_w1);

  Expression input_mat = concatenate_cols(encoded);
  Expression last_output_embeddings = lookup(cg, output_1_lookup, kSOS);
  // Precompute W1 * H once; attend_1 reuses it at every step.
  Expression w1dt = w1 * input_mat;

  // Seed the decoder with the last encoder state plus the SOS embedding.
  dec_1_lstm.add_input(concatenate( { encoded.back(), last_output_embeddings }));

  vector<Expression> loss;
  for (int c : embeddings_1) {
    // Context vector weighted by the attention distribution.
    Expression att_weights_1 = attend_1(dec_1_lstm, w1dt, cg);
    Expression context_1 = input_mat * att_weights_1;
    Expression dec_input = concatenate( { context_1, last_output_embeddings });
    dec_1_lstm.add_input(dec_input);
    Expression probs = softmax(w * dec_1_lstm.back() + b);
    last_output_embeddings = lookup(cg, output_1_lookup, c);  // teacher forcing
    loss.push_back(-log(pick(probs, c)));
  }
  Expression lossint = sum(loss);

  // Second decoder — same procedure with its own parameters; it attends
  // over the same encoder outputs, not over decoder-1 states.
  Expression w_2 = parameter(cg, decoder_2_w);
  Expression b_2 = parameter(cg, decoder_2_b);
  Expression w1_2 = parameter(cg, attention_2_w1);

  Expression last_output_embeds = lookup(cg, output_2_lookup, kSOS);
  Expression w1dt_2 = w1_2 * input_mat;

  dec_2_lstm.add_input( concatenate( { encoded.back(), last_output_embeds } ) );

  vector<Expression> loss2;
  for (int c : embeddings_2) {
    Expression att_weights_2 = attend_2(dec_2_lstm, w1dt_2, cg);
    Expression context_2 = input_mat * att_weights_2;
    Expression dec_input = concatenate( { context_2, last_output_embeds });
    dec_2_lstm.add_input(dec_input);
    Expression probs = softmax(w_2 * dec_2_lstm.back() + b_2);
    last_output_embeds = lookup(cg, output_2_lookup, c);  // teacher forcing
    loss2.push_back(-log(pick(probs, c)));
  }
  Expression losstrg = sum(loss2);

  return lossint + losstrg;
}


//Version that dumps attentions
void simplemultitask::decode_attentions(LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, vector<Expression>& encoded, const vector<int>& int_sentence, const vector<int>& trg_sentence, ComputationGraph& cg) {
  // Teacher-forced pass over both decoders that prints each step's attention
  // distribution over the encoder positions to stdout.
  //
  // Cleanup vs. previous version: the loss expressions (and the softmax
  // projections feeding them) were built but never evaluated, so that dead
  // graph construction has been removed.  The printed attention values and
  // the output format are unchanged.

  // Reference sequences with EOS appended.
  vector<int> embeddings_1 = int_sentence;
  embeddings_1.push_back(kEOS);
  vector<int> embeddings_2 = trg_sentence;
  embeddings_2.push_back(kEOS);

  Expression input_mat = concatenate_cols(encoded);

  // First decoder: attention over the encoder outputs.
  Expression w1 = parameter(cg, attention_1_w1);
  Expression last_output_embeddings = lookup(cg, output_1_lookup, kSOS);
  Expression w1dt = w1 * input_mat;  // precompute W1 * H once

  dec_1_lstm.add_input(concatenate( { encoded.back(), last_output_embeddings }));

  cout << "Attention 1" << endl;
  for (int c : embeddings_1) {
    Expression att_weights_1 = attend_1(dec_1_lstm, w1dt, cg);
    cout << as_vector(att_weights_1.value()) << endl;
    Expression context_1 = input_mat * att_weights_1;
    dec_1_lstm.add_input(concatenate( { context_1, last_output_embeddings }));
    last_output_embeddings = lookup(cg, output_1_lookup, c);  // teacher forcing
  }

  // Second decoder: same procedure with its own attention parameters.
  Expression w1_2 = parameter(cg, attention_2_w1);
  Expression last_output_embeds = lookup(cg, output_2_lookup, kSOS);
  Expression w1dt_2 = w1_2 * input_mat;

  dec_2_lstm.add_input( concatenate( { encoded.back(), last_output_embeds } ) );

  cout << endl << "Attention 2" << endl;
  for (int c : embeddings_2) {
    Expression att_weights_2 = attend_2(dec_2_lstm, w1dt_2, cg);
    cout << as_vector(att_weights_2.value()) << endl;
    Expression context_2 = input_mat * att_weights_2;
    dec_2_lstm.add_input(concatenate( { context_2, last_output_embeds }));
    last_output_embeds = lookup(cg, output_2_lookup, c);  // teacher forcing
  }

  return;
}



// Batch version
Expression simplemultitask::decode(LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, vector<Expression>& encoded, const vector<vector<int>>& int_batch, const vector<vector<int>>& trg_batch, ComputationGraph& cg) {
  // Teacher-forced cascade decoding over a minibatch.  Sentences shorter
  // than the step limit are padded with kEOS.  Returns the batch-summed
  // negative log-likelihood of both output sequences.
  //
  // Fixes vs. previous version:
  //  * The second decoder looped to the FIRST task's max length and filled
  //    the first task's `words` buffer (`words_2` was declared but unused),
  //    truncating or over-padding targets whose length differed.  It now
  //    uses the target batch's own maximum length and `words_2`.
  //  * Dead locals (`embeddings_1/2`, `att_1/2`, `decoder_1_states`) removed.

  // Minibatches are sorted so int_batch[0] is the longest int sentence;
  // target lengths are not guaranteed sorted, so scan for the max.
  const unsigned max_len_1 = int_batch[0].size();
  unsigned max_len_2 = 0;
  for (const auto& sent : trg_batch)
    if (sent.size() > max_len_2) max_len_2 = (unsigned)sent.size();

  // First decoder
  Expression w = parameter(cg, decoder_1_w);
  Expression b = parameter(cg, decoder_1_b);
  Expression w1 = parameter(cg, attention_1_w1);

  // Per-sentence gold word at the current step, initialized to SOS.
  vector<unsigned> words(int_batch.size(), kSOS);
  Expression input_mat = concatenate_cols(encoded);

  Expression last_output_embeddings = lookup(cg, output_1_lookup, kSOS);
  // Precompute W1 * H once; attend_1 reuses it at every step.
  Expression w1dt = w1 * input_mat;

  dec_1_lstm.add_input(concatenate( { encoded.back(), last_output_embeddings }));

  vector<Expression> loss;
  for (unsigned t = 0; t <= max_len_1; t++) {
    // Gold word at step t for every sentence (kEOS past the end).
    for (size_t i = 0; i < int_batch.size(); i++)
      words[i] = (t < int_batch[i].size() ? int_batch[i][t] : kEOS);

    Expression att_weights_1 = attend_1(dec_1_lstm, w1dt, cg);
    Expression context_1 = input_mat * att_weights_1;
    dec_1_lstm.add_input(concatenate( { context_1, last_output_embeddings }));
    Expression out_vector = w * dec_1_lstm.back() + b;
    last_output_embeddings = lookup(cg, output_1_lookup, words);  // teacher forcing
    loss.push_back(pickneglogsoftmax(out_vector, words));   // batched loss
  }
  Expression lossint = sum_batches(sum(loss));

  // Second decoder — same procedure with its own parameters.
  Expression w_2 = parameter(cg, decoder_2_w);
  Expression b_2 = parameter(cg, decoder_2_b);
  Expression w1_2 = parameter(cg, attention_2_w1);

  vector<unsigned> words_2(trg_batch.size(), kSOS);

  Expression last_output_embeds = lookup(cg, output_2_lookup, kSOS);
  Expression w1dt_2 = w1_2 * input_mat;

  dec_2_lstm.add_input( concatenate( { encoded.back(), last_output_embeds } ) );

  vector<Expression> loss2;
  for (unsigned t = 0; t <= max_len_2; t++) {
    for (size_t i = 0; i < trg_batch.size(); i++)
      words_2[i] = (t < trg_batch[i].size() ? trg_batch[i][t] : kEOS);

    Expression att_weights_2 = attend_2(dec_2_lstm, w1dt_2, cg);
    Expression context_2 = input_mat * att_weights_2;
    dec_2_lstm.add_input(concatenate( { context_2, last_output_embeds }));
    Expression out_vector = w_2 * dec_2_lstm.back() + b_2;
    last_output_embeds = lookup(cg, output_2_lookup, words_2);  // teacher forcing
    loss2.push_back(pickneglogsoftmax(out_vector, words_2));   // batched loss
  }
  Expression losstrg = sum_batches(sum(loss2));

  return lossint + losstrg;
}

Expression simplemultitask::get_loss(const vector<vector<float>>&  src_sentence, const vector<int>&  int_sentence, const vector<int>&  trg_sentence, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, 
    LSTMBuilder& enc_compress_lstm_1, LSTMBuilder& enc_compress_lstm_2, LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, ComputationGraph& cg) {
  // Single-sentence training loss: embed -> encode -> cascade decode.
  vector<Expression> embedded = STACK_FEATS ? embed_stack_features(src_sentence, cg)
                                            : embed_features(src_sentence, cg);
  vector<Expression> encoded = encode_features(enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, embedded);
  return decode(dec_1_lstm, dec_2_lstm, encoded, int_sentence, trg_sentence, cg);
}
// Batch version
Expression simplemultitask::get_loss(const vector<vector<vector<float>>>&  input_batch, const vector<vector<int>>&  int_batch, const vector<vector<int>>&  trg_batch, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, 
    LSTMBuilder& enc_compress_lstm_1, LSTMBuilder& enc_compress_lstm_2, LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, ComputationGraph& cg) {
  // Minibatch training loss: embed -> encode -> cascade decode.
  vector<Expression> embedded = STACK_FEATS ? embed_stack_features(input_batch, cg)
                                            : embed_features(input_batch, cg);
  vector<Expression> encoded = encode_features(enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, embedded);
  return decode(dec_1_lstm, dec_2_lstm, encoded, int_batch, trg_batch, cg);
}

float simplemultitask::train(ParameterCollection& model, const vector<vector<vector<float>>>&  source_batch, const vector<vector<int>>& int_batch, const vector<vector<int>>& target_batch, AdamTrainer& trainer, dynet::real l_scale) {
  // One optimization step on a minibatch: build the graph, run forward to
  // get the loss, backpropagate, and update the weights.
  // NOTE(review): `l_scale` is accepted but currently unused here — confirm
  // whether loss scaling was intended.
  ComputationGraph cg;

  // Attach every recurrent builder to the fresh graph and reset its state.
  for (LSTMBuilder* lstm : { &enc_fwd_lstm, &enc_bwd_lstm,
                             &enc_compress_lstm_1, &enc_compress_lstm_2,
                             &dec_1_lstm, &dec_2_lstm }) {
    lstm->new_graph(cg);
    lstm->start_new_sequence();
  }

  Expression loss = get_loss(source_batch, int_batch, target_batch, enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, dec_1_lstm, dec_2_lstm, cg);
  const float loss_value = as_scalar(cg.forward(loss));  // forward propagation
  cg.backward(loss);   // backward propagation
  trainer.update();    // update network weights
  return loss_value;
}

void simplemultitask::dump_attentions(ParameterCollection& model, const vector<vector<float>>& source_sentence, const vector<int>& int_sentence, const vector<int>& trg_sentence){
  // Run one teacher-forced pass and print both decoders' attention
  // distributions to stdout (the actual printing is in decode_attentions).
  ComputationGraph cg;

  // Attach every recurrent builder to the fresh graph and reset its state.
  for (LSTMBuilder* lstm : { &enc_fwd_lstm, &enc_bwd_lstm,
                             &enc_compress_lstm_1, &enc_compress_lstm_2,
                             &dec_1_lstm, &dec_2_lstm }) {
    lstm->new_graph(cg);
    lstm->start_new_sequence();
  }

  vector<Expression> embedded = STACK_FEATS ? embed_stack_features(source_sentence, cg)
                                            : embed_features(source_sentence, cg);
  vector<Expression> encoded = encode_features(enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, embedded);

  decode_attentions(dec_1_lstm, dec_2_lstm, encoded, int_sentence, trg_sentence, cg);
}
    

// Development-set evaluation for one example; delegates to the
// BLEU-based scorer (returns the sum of both tasks' BLEU scores).
float simplemultitask::test_dev(ParameterCollection& model, const vector<vector<float>>& source_sentence, const vector<int>& int_sentence, const vector<int>& target_sentence) {
    return test_dev_bleu(model, source_sentence, int_sentence, target_sentence);
  
}

void simplemultitask::test(ParameterCollection& model, const vector<vector<float>>& source_sentence, int beamsize) {
  // Decode a single source sequence with beam search and print
  // "int-seq  ||| trg-seq " to stdout.
  // Fix: the previous `out = out + ...` rebuilt the string each iteration
  // (quadratic); append in place instead.
  ComputationGraph cg;
  auto output =  generate_nbest(source_sentence, enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, dec_1_lstm, dec_2_lstm, cg, beamsize, 1 ,beamsize);
  string out;
  for (auto c : std::get<0>(output) ) {
    out += int_d.convert(c);
    out += " ";
  }
  out += " ||| ";
  for (auto c : std::get<1>(output) ) {
    out += trg_d.convert(c);
    out += " ";
  }
  cout << out << endl; 
  return;
}

float simplemultitask::test_dev_bleu(ParameterCollection& model, const vector<vector<float>>& source_sentence, const vector<int>& int_sentence, const vector<int>& target_sentence) {
  // Greedy decode (beam size 1, 1-best for both stages), then score each
  // output sequence against its reference and return the summed BLEU.
  ComputationGraph cg;
  auto hyps = generate_nbest(source_sentence, enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, dec_1_lstm, dec_2_lstm, cg, 1, 1, 1);
  return bleu(std::get<0>(hyps), int_sentence) + bleu(std::get<1>(hyps), target_sentence);
}


// Two-stage beam-search decoding.
// Encodes `in_seq`, beam-searches the first (int) decoder keeping up to
// nbest_1_size finished hypotheses, then beam-searches the second (trg)
// decoder over the SAME encoder outputs, and returns the pair
// (1-best int sequence, 1-best trg sequence).
// NOTE(review): only the 1-best of each stage is returned; the n-best
// lists are used for pruning/termination during search.
std::tuple<vector<int>,vector<int>> simplemultitask::generate_nbest(const vector<vector<float>>& in_seq, 
      LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, LSTMBuilder& enc_compress_lstm_1, LSTMBuilder& enc_compress_lstm_2, LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, 
      ComputationGraph& cg, int nbest_1_size, int nbest_2_size, int beamsize) {
  
  // Attach every recurrent builder to this graph and reset its state.
  enc_fwd_lstm.new_graph(cg);
  enc_bwd_lstm.new_graph(cg);
  enc_compress_lstm_1.new_graph(cg);
  enc_compress_lstm_2.new_graph(cg);
  dec_1_lstm.new_graph(cg);
  dec_2_lstm.new_graph(cg);
  enc_fwd_lstm.start_new_sequence();
  enc_bwd_lstm.start_new_sequence();
  enc_compress_lstm_1.start_new_sequence();
  enc_compress_lstm_2.start_new_sequence();
  dec_1_lstm.start_new_sequence();
  dec_2_lstm.start_new_sequence();

  // Embed the raw feature frames, then run the encoder stack.
  vector<Expression> embedded;
  if (STACK_FEATS)
    embedded = embed_stack_features(in_seq, cg);
  else
    embedded = embed_features(in_seq, cg);
  vector<Expression> encoded = encode_features(enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, embedded);
  
  // The n-best finished hypotheses for the first and second decoders
  vector<DecoderHypPtr> nbest_1, nbest_2;

  // These will hold the best sentences from the int and trg decoder
  vector<int> best_int, best_trg;

  // First decoder's beam search
  Expression w = parameter(cg, decoder_1_w);
  Expression b = parameter(cg, decoder_1_b);
  Expression w1 = parameter(cg, attention_1_w1);
  Expression input_mat = concatenate_cols(encoded);
  // Precompute W1 * H once; attend_1 reuses it at every step.
  Expression w1dt = w1 * input_mat;

  cg.incremental_forward(w1dt);

  // Final LSTM state of each live hypothesis, indexed by beam position.
  vector<vector<Expression>> last_states(beamsize);

  // Initial decoder input: last encoder state plus the SOS embedding.
  Expression last_output_embeddings = lookup(cg, output_1_lookup, kSOS);
  Expression init_vector = concatenate( { encoded.back(), last_output_embeddings });
  dec_1_lstm.add_input(init_vector);
  vector<Expression> last_s = dec_1_lstm.final_s();

  // Seed the beam with a single SOS hypothesis of score 0.
  vector<int> init_sent;
  init_sent.push_back(kSOS);
  vector<DecoderHypPtr> curr_beam(1, DecoderHypPtr(new DecoderHyp(0.0, last_s, init_sent)));

  // Output-length cap: twice the input length, but for inputs longer than
  // 80 frames fall back to the input length itself.
  int size_limit_ = in_seq.size() * 2;
  if (size_limit_ > 80) size_limit_ = in_seq.size();

  // Beam search 1st decoder
  for (int sent_len = 0; sent_len <= size_limit_; sent_len++) {
    // Best (score, hypothesis id, word id) expansions, kept sorted
    // best-first; the extra slot at index `beamsize` is scratch space
    // used by the insertion below.
    vector<tuple<dynet::real,int,int> > next_beam_id(beamsize+1, tuple<dynet::real,int,int>(-400000,-1,-1));

    // Iterate over the current beams and go one step forward
    for(int hypid = 0; hypid < (int)curr_beam.size(); hypid++) {
      // Get the current hypothesis
      DecoderHypPtr curr_hyp = curr_beam[hypid];
      // Get the current hypothesis sentence
      const vector<int>& sent = curr_hyp->GetSentence();
      
      // Do not expand a finished beam
      if (sent[sent_len] == kEOS) continue;

      Expression last_output_embeddings = lookup(cg, output_1_lookup, sent[sent_len]);
  
      // Perform the forward step on the decoder (after init with its last state)
      if (sent_len == 0){
        dec_1_lstm.start_new_sequence();
        dec_1_lstm.add_input(init_vector);
      }
      else
        dec_1_lstm.start_new_sequence(curr_hyp->GetStates());

      //concatenate input weighted by attention and decoder lstm state
      Expression att_weights_1 = attend_1(dec_1_lstm, w1dt, cg);
      Expression context_1 = input_mat * att_weights_1;
      Expression in_vector = concatenate( { context_1, last_output_embeddings }); 
      dec_1_lstm.add_input(in_vector);
          
      // And now get the log-softmax over the vocabulary
      Expression out_vector = log(softmax(w * dec_1_lstm.back() + b));
      // Length normalization (GNMT-style: ((5+len)/6)^alpha)
      float length_norm = pow(5 + sent_len, LENGTH_NORM_WEIGHT)/(pow(6,LENGTH_NORM_WEIGHT));
      vector<float> probs = as_vector(cg.incremental_forward(out_vector / length_norm));

      // Add unknown word penalty
      if(INT_UNK_ID >= 0) probs[INT_UNK_ID] += int_unk_log_prob_;
      // Keep the final state for the continuation of the beam
      last_states[hypid] = dec_1_lstm.final_s();
      // Insertion-sort every candidate word into the top-<beamsize> list
      for(int wid = 0; wid < (int)probs.size(); wid++) {
        // The new score will be the current score + the log-softmax score
        dynet::real my_score = curr_hyp->GetScore() + probs[wid];
        // Shift worse entries down and insert, keeping only the best
        // <beamsize> candidates
        int bid;
        for(bid = beamsize; bid > 0 && my_score > get<0>(next_beam_id[bid-1]); bid--)
          next_beam_id[bid] = next_beam_id[bid-1];
        next_beam_id[bid] = tuple<dynet::real,int,int>(my_score,hypid,wid);
      }
    }
    // Create the new hypotheses
    vector<DecoderHypPtr> next_beam;
    for(int i = 0; i < beamsize; i++) {
      dynet::real score = get<0>(next_beam_id[i]);
      int hypid = get<1>(next_beam_id[i]);
      int wid = get<2>(next_beam_id[i]);
      // Sentinel entry: fewer than beamsize candidates were produced
      if(hypid == -1) break;
      // Add the last word to the sentence
      vector<int> next_sent = curr_beam[hypid]->GetSentence();
      next_sent.push_back(wid);
      DecoderHypPtr hyp(new DecoderHyp(score, last_states[hypid], next_sent));     
      // If we are done, add it to the nbest list
      if(wid == kEOS || sent_len == size_limit_) 
        nbest_1.push_back(hyp);
      // Add it to the next beams to be expanded
      next_beam.push_back(hyp);
    }

    // Substitute beams with the next ones 
    curr_beam = next_beam;
    // Check if we're done with search
    if(nbest_1.size() != 0) {
      sort(nbest_1.begin(), nbest_1.end());
      // trim to top n options
      if(nbest_1.size() > nbest_1_size)
        nbest_1.resize(nbest_1_size);
      // Stop early when the n-best list is full and no live beam can beat
      // its worst entry
      if(nbest_1.size() == nbest_1_size && (next_beam.size() == 0 || (*nbest_1.rbegin())->GetScore() >= next_beam[0]->GetScore()))
        break;
    }
  }

  // NOTE(review): assumes at least one finished hypothesis exists;
  // an empty nbest_1 here would be undefined behavior — confirm upstream.
  best_int = nbest_1[0]->GetSentence();

  // Second decoder's beam search (re-attached to the graph; it attends
  // over the same encoder outputs as the first decoder)
  dec_2_lstm.new_graph(cg);
  dec_2_lstm.start_new_sequence();
  Expression w_2 = parameter(cg, decoder_2_w);
  Expression b_2 = parameter(cg, decoder_2_b);
  Expression w1_2 = parameter(cg, attention_2_w1);
  Expression last_output_embeds = lookup(cg, output_2_lookup, kSOS);
  Expression w1dt_2 = w1_2 * input_mat;
  cg.incremental_forward(w1dt_2);
  vector<vector<Expression>> last_2_states(beamsize);
  Expression init_2_vector = concatenate( { encoded.back(), last_output_embeds } );
  dec_2_lstm.add_input( init_2_vector);
  vector<Expression> last_2_s = dec_2_lstm.final_s();
  vector<int> init_sent_2;
  init_sent_2.push_back(kSOS);
  vector<DecoderHypPtr> curr_beam_2(1, DecoderHypPtr(new DecoderHyp(0.0, last_2_s, init_sent_2))); 

  // Beam search 2nd decoder — same algorithm as above with the second
  // decoder's parameters, vocabulary, and UNK penalty
  for (int sent_len = 0; sent_len <= size_limit_; sent_len++) {
    // Best (score, hypothesis id, word id) expansions, sorted best-first
    vector<tuple<dynet::real,int,int> > next_beam_id(beamsize+1, tuple<dynet::real,int,int>(-400000,-1,-1));

    // Iterate over the current beams and go one step forward
    for(int hypid = 0; hypid < (int)curr_beam_2.size(); hypid++) {
      // Get the current hypothesis
      DecoderHypPtr curr_hyp = curr_beam_2[hypid];
      // Get the current hypothesis sentence
      const vector<int>& sent = curr_hyp->GetSentence();
      // Do not expand a finished beam
      if (sent[sent_len] == kEOS) continue;

      Expression last_output_embeds = lookup(cg, output_2_lookup, sent[sent_len]);

      // Perform the forward step on the decoder (after init with its last state)
      if (sent_len == 0){
        dec_2_lstm.start_new_sequence();
        dec_2_lstm.add_input(init_2_vector);
      }
      else
        dec_2_lstm.start_new_sequence(curr_hyp->GetStates());

      //concatenate input weighted by attention and decoder lstm state
      Expression att_weights_2 = attend_2(dec_2_lstm, w1dt_2, cg);
      Expression context_2 = input_mat * att_weights_2;
      Expression in_vector = concatenate( { context_2, last_output_embeds }); 
      dec_2_lstm.add_input(in_vector);
          
      // And now get the log-softmax over the vocabulary
      Expression out_vector = log(softmax(w_2 * dec_2_lstm.back() + b_2));
      // Length normalization (GNMT-style: ((5+len)/6)^alpha)
      float length_norm = pow(5 + sent_len, LENGTH_NORM_WEIGHT)/(pow(6,LENGTH_NORM_WEIGHT));
      vector<float> probs = as_vector(cg.incremental_forward(out_vector / length_norm));

      // Add unknown word penalty
      if(TRG_UNK_ID >= 0) probs[TRG_UNK_ID] += trg_unk_log_prob_;
      // Keep the final state for the continuation of the beam
      last_2_states[hypid] = dec_2_lstm.final_s();
      // Insertion-sort every candidate word into the top-<beamsize> list
      for(int wid = 0; wid < (int)probs.size(); wid++) {
        // The new score will be the current score + the log-softmax score
        dynet::real my_score = curr_hyp->GetScore() + probs[wid];
        // Shift worse entries down and insert, keeping only the best
        // <beamsize> candidates
        int bid;
        for(bid = beamsize; bid > 0 && my_score > get<0>(next_beam_id[bid-1]); bid--)
          next_beam_id[bid] = next_beam_id[bid-1];
        next_beam_id[bid] = tuple<dynet::real,int,int>(my_score,hypid,wid);
      }
    }
    // Create the new hypotheses
    vector<DecoderHypPtr> next_beam;
    for(int i = 0; i < beamsize; i++) {
      dynet::real score = get<0>(next_beam_id[i]);
      int hypid = get<1>(next_beam_id[i]);
      int wid = get<2>(next_beam_id[i]);
      // Sentinel entry: fewer than beamsize candidates were produced
      if(hypid == -1) break;
      // Add the last word to the sentence
      vector<int> next_sent = curr_beam_2[hypid]->GetSentence();
      next_sent.push_back(wid);
      DecoderHypPtr hyp(new DecoderHyp(score, last_2_states[hypid], next_sent));     
      // If we are done, add it to the nbest list
      if(wid == kEOS || sent_len == size_limit_) 
        nbest_2.push_back(hyp);
      // Add it to the next beams to be expanded
      next_beam.push_back(hyp);
    }

    // Substitute beams with the next ones 
    curr_beam_2 = next_beam;
    // Check if we're done with search
    if(nbest_2.size() != 0) {
      sort(nbest_2.begin(), nbest_2.end());
      // trim to top n options
      if(nbest_2.size() > nbest_2_size)
        nbest_2.resize(nbest_2_size);
      // Stop early when the n-best list is full and no live beam can beat
      // its worst entry
      if (nbest_2.size() == nbest_2_size && (next_beam.size() == 0 || (*nbest_2.rbegin())->GetScore() >= next_beam[0]->GetScore()))
        break;
    }
  } 

  // NOTE(review): same empty-list assumption as nbest_1 above.
  best_trg = nbest_2[0]->GetSentence();

  return make_tuple(best_int, best_trg);

}




/* Definitions for simpleunitask model
 *
 *
 */

void simpleunitask::initialize(ParameterCollection& model) {

  ENC_LSTM_NUM_OF_LAYERS = NUM_LAYERS;
  DEC_LSTM_NUM_OF_LAYERS = NUM_LAYERS;

  // When STACK_FEATS is set the encoder input is eight feature frames wide.
  const unsigned enc_input_dim = STACK_FEATS ? FEAT_SIZE * 8 : FEAT_SIZE;
  enc_fwd_lstm = LSTMBuilder(1, enc_input_dim, STATE_SIZE1, model);
  enc_bwd_lstm = LSTMBuilder(1, enc_input_dim, STATE_SIZE1, model);

  // Two-stage compression stack over the concatenated bi-LSTM outputs.
  enc_compress_lstm_1 = LSTMBuilder(1, STATE_SIZE1 * 2, STATE_SIZE, model);
  enc_compress_lstm_2 = LSTMBuilder(1, STATE_SIZE, STATE_SIZE, model);

  // Decoder consumes [attention context ; previous output embedding].
  dec_1_lstm = LSTMBuilder(DEC_LSTM_NUM_OF_LAYERS, STATE_SIZE + EMBEDDINGS_SIZE, STATE_SIZE, model);

  // Apply the same dropout rate to every recurrent component.
  for (LSTMBuilder* builder : { &enc_fwd_lstm, &enc_bwd_lstm,
                                &enc_compress_lstm_1, &enc_compress_lstm_2,
                                &dec_1_lstm })
    builder->set_dropout(DROPOUT);

  output_1_lookup = model.add_lookup_parameters(INT_VOCAB_SIZE, { EMBEDDINGS_SIZE });

  // MLP attention parameters: v^T tanh(W1*H + W2*s).
  attention_1_w1 = model.add_parameters( { ATTENTION_SIZE, STATE_SIZE });
  attention_1_w2 = model.add_parameters( { ATTENTION_SIZE, STATE_SIZE * DEC_LSTM_NUM_OF_LAYERS * 2 });
  attention_1_v = model.add_parameters( { 1, ATTENTION_SIZE });

  // Output projection to the target vocabulary.
  decoder_1_w = model.add_parameters( { INT_VOCAB_SIZE, STATE_SIZE });
  decoder_1_b = model.add_parameters( { INT_VOCAB_SIZE });
}


// Compute normalized attention weights over the encoded input columns.
// w1dt is the precomputed (W1 * input matrix) term; the decoder's full
// state (all layers, via final_s) supplies the query through W2.
// Returns a column vector of softmax-normalized weights.
Expression simpleunitask::attend_1(LSTMBuilder& state, Expression w1dt, ComputationGraph& cg) {
  const Expression att_w2 = parameter(cg, attention_1_w2);
  const Expression att_v = parameter(cg, attention_1_v);
  // Query term derived from the concatenated decoder state.
  const Expression query = att_w2 * concatenate(state.final_s());
  // Unnormalized scores: v^T tanh(W1*H + W2*s), one per input position.
  const Expression scores = transpose(att_v * tanh(colwise_add(w1dt, query)));
  return softmax(scores);
}



// One sentence version
// Teacher-forced decoding of a single sentence.
// Runs the attentional decoder over int_sentence (plus a terminating EOS)
// and returns the summed negative log-likelihood of the gold symbols.
Expression simpleunitask::decode(LSTMBuilder& dec_1_lstm, vector<Expression>& encoded, const vector<int>& int_sentence, ComputationGraph& cg) {

  // Target symbols to predict: the sentence followed by EOS.
  vector<int> targets(int_sentence);
  targets.push_back(kEOS);

  // First decoder
  Expression w = parameter(cg, decoder_1_w);
  Expression b = parameter(cg, decoder_1_b);
  Expression w1 = parameter(cg, attention_1_w1);

  Expression input_mat = concatenate_cols(encoded);

  Expression last_output_embeddings = lookup(cg, output_1_lookup, kSOS);
  // Precompute W1 * H once; attend_1 reuses it at every step.
  Expression w1dt = w1 * input_mat;

  // Prime the decoder with the last encoder state and the SOS embedding.
  dec_1_lstm.add_input(concatenate( { encoded.back(), last_output_embeddings }));

  vector<Expression> loss;

  for (int c : targets) {
    // Concatenate input weighted by attention and decoder lstm state.
    Expression att_weights_1 = attend_1(dec_1_lstm, w1dt, cg);
    Expression context_1 = input_mat * att_weights_1;

    Expression dec_input = concatenate( { context_1, last_output_embeddings });
    dec_1_lstm.add_input(dec_input);
    Expression out_vector = w * dec_1_lstm.back() + b;
    // Teacher forcing: feed the gold symbol as the next input embedding.
    last_output_embeddings = lookup(cg, output_1_lookup, c);
    // pickneglogsoftmax == -log(softmax(out)[c]) but numerically stable,
    // and consistent with the loss used by the batch version of decode().
    loss.push_back(pickneglogsoftmax(out_vector, c));
  }

  return sum(loss);
}

// Version that dumps attentions
// Replays the decoder over the gold sentence (teacher forcing) and prints
// the attention weight vector computed at every step, one line per symbol.
// The original dead loss/softmax computation has been removed: the printed
// attention weights depend only on the decoder state chain and w1dt, not
// on the output projection.
void simpleunitask::decode_attentions(LSTMBuilder& dec_1_lstm, vector<Expression>& encoded, const vector<int>& int_sentence, ComputationGraph& cg) {

  // Target symbols: the sentence followed by EOS (same schedule as decode()).
  vector<int> targets(int_sentence);
  targets.push_back(kEOS);

  Expression w1 = parameter(cg, attention_1_w1);
  Expression input_mat = concatenate_cols(encoded);
  Expression last_output_embeddings = lookup(cg, output_1_lookup, kSOS);
  // Precompute W1 * H once; attend_1 reuses it at every step.
  Expression w1dt = w1 * input_mat;

  // Prime the decoder with the last encoder state and the SOS embedding.
  dec_1_lstm.add_input(concatenate( { encoded.back(), last_output_embeddings }));

  cout << "Attention 1" << endl;
  for (int c : targets) {
    Expression att_weights_1 = attend_1(dec_1_lstm, w1dt, cg);
    Expression context_1 = input_mat * att_weights_1;
    // Force evaluation of the attention weights and print them.
    cout << as_vector(att_weights_1.value()) << endl;
    // Advance the decoder state with [context ; previous gold embedding].
    dec_1_lstm.add_input(concatenate( { context_1, last_output_embeddings }));
    // Teacher forcing: the gold symbol becomes the next input embedding.
    last_output_embeddings = lookup(cg, output_1_lookup, c);
  }

  return;
}

// Batch version
// Teacher-forced decoding over a minibatch. Sentences must be sorted so
// that int_batch[0] is the longest; shorter sentences are padded with EOS.
// Returns the total negative log-likelihood summed over steps and batch.
Expression simpleunitask::decode(LSTMBuilder& dec_1_lstm, vector<Expression>& encoded, const vector<vector<int>>& int_batch, ComputationGraph& cg) {

  // The first output sentence should be the longest.
  const unsigned max_len_1 = int_batch[0].size();

  // First decoder
  Expression w = parameter(cg, decoder_1_w);
  Expression b = parameter(cg, decoder_1_b);
  Expression w1 = parameter(cg, attention_1_w1);

  // Initialize all sentences with SOS and get them as input.
  vector<unsigned> words(int_batch.size(), kSOS);
  Expression input_mat = concatenate_cols(encoded);

  Expression last_output_embeddings = lookup(cg, output_1_lookup, kSOS);
  // Precompute W1 * H once; attend_1 reuses it at every step.
  Expression w1dt = w1 * input_mat;

  // Prime the decoder with the last encoder state and the SOS embedding.
  dec_1_lstm.add_input(concatenate( { encoded.back(), last_output_embeddings }));

  vector<Expression> loss;

  // One extra step (t == max_len_1) so every sentence predicts its EOS.
  for (unsigned t = 0; t <= max_len_1; t++) {
    // Gold symbol per batch element; EOS once a sentence is exhausted.
    for (size_t i = 0; i < int_batch.size(); i++)
      words[i] = (t < int_batch[i].size() ? int_batch[i][t] : kEOS);

    Expression att_weights_1 = attend_1(dec_1_lstm, w1dt, cg);
    Expression context_1 = input_mat * att_weights_1;

    Expression dec_input = concatenate( { context_1, last_output_embeddings });
    dec_1_lstm.add_input(dec_input);
    Expression out_vector = w * dec_1_lstm.back() + b;
    last_output_embeddings = lookup(cg, output_1_lookup, words);
    loss.push_back(pickneglogsoftmax(out_vector, words));   // Batch version for loss
  }

  return sum_batches(sum(loss));
}


// Single-sentence loss: embed -> bi-LSTM encode -> attentional decode.
Expression simpleunitask::get_loss(const vector<vector<float>>&  src_sentence, const vector<int>&  int_sentence, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, 
    LSTMBuilder& enc_compress_lstm_1, LSTMBuilder& enc_compress_lstm_2, LSTMBuilder& dec_1_lstm, ComputationGraph& cg) {
  // Select the embedding routine based on the feature-stacking flag.
  vector<Expression> embedded = STACK_FEATS ? embed_stack_features(src_sentence, cg)
                                            : embed_features(src_sentence, cg);
  vector<Expression> encoded = encode_features(enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, embedded);
  return decode(dec_1_lstm, encoded, int_sentence, cg);
}
// Batch version
// Minibatch loss: embed -> bi-LSTM encode -> attentional decode.
Expression simpleunitask::get_loss(const vector<vector<vector<float>>>&  input_batch, const vector<vector<int>>&  int_batch, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, 
    LSTMBuilder& enc_compress_lstm_1, LSTMBuilder& enc_compress_lstm_2, LSTMBuilder& dec_1_lstm, ComputationGraph& cg) { 
  // Select the embedding routine based on the feature-stacking flag.
  vector<Expression> embedded = STACK_FEATS ? embed_stack_features(input_batch, cg)
                                            : embed_features(input_batch, cg);
  vector<Expression> encoded = encode_features(enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, embedded);
  return decode(dec_1_lstm, encoded, int_batch, cg);
}

// One training step on a minibatch: forward, backward, optimizer update.
// Returns the scalar loss value for this batch.
float simpleunitask::train(ParameterCollection& model, const vector<vector<vector<float>>>&  source_batch, const vector<vector<int>>& int_batch, AdamTrainer& trainer, dynet::real l_scale) {
    // Build a fresh computation graph and reset every recurrent component.
    ComputationGraph cg;
    for (LSTMBuilder* builder : { &enc_fwd_lstm, &enc_bwd_lstm,
                                  &enc_compress_lstm_1, &enc_compress_lstm_2,
                                  &dec_1_lstm }) {
      builder->new_graph(cg);
      builder->start_new_sequence();
    }
    Expression loss = get_loss(source_batch, int_batch, enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, dec_1_lstm, cg);
    const float loss_value = as_scalar(cg.forward(loss)); // forward propagation
    cg.backward(loss);  // backward propagation
    trainer.update();   // update network weights
    return loss_value;
}

// Encode one source sentence and print the decoder's per-step attention
// weights (via decode_attentions) to stdout.
void simpleunitask::dump_attentions(ParameterCollection& model, const vector<vector<float>>& source_sentence, const vector<int>& int_sentence){
    // Build a fresh computation graph and reset every recurrent component.
    ComputationGraph cg;
    for (LSTMBuilder* builder : { &enc_fwd_lstm, &enc_bwd_lstm,
                                  &enc_compress_lstm_1, &enc_compress_lstm_2,
                                  &dec_1_lstm }) {
      builder->new_graph(cg);
      builder->start_new_sequence();
    }

    // Select the embedding routine based on the feature-stacking flag.
    vector<Expression> embedded = STACK_FEATS ? embed_stack_features(source_sentence, cg)
                                              : embed_features(source_sentence, cg);
    vector<Expression> encoded = encode_features(enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, embedded);

    decode_attentions(dec_1_lstm, encoded, int_sentence, cg);
}
    


// Dev-set evaluation entry point; currently delegates to the BLEU scorer.
float simpleunitask::test_dev(ParameterCollection& model, const vector<vector<float>>&  source_sentence, const vector<int>& int_sentence) {
    return test_dev_bleu(model, source_sentence, int_sentence);
  
}

// Decode the source with a size-1 beam (greedy search) and score the
// hypothesis against the reference int_sentence using BLEU.
float simpleunitask::test_dev_bleu(ParameterCollection& model, const vector<vector<float>>&  source_sentence, const vector<int>& int_sentence) {
    ComputationGraph cg;
    return bleu(generate_nbest(source_sentence, enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, dec_1_lstm, cg, 1, 1),
                int_sentence);
}

// Decode the source sentence with beam search and print the resulting
// tokens to stdout, space-separated (with a trailing space, as before).
void simpleunitask::test(ParameterCollection& model, const vector<vector<float>>&  source_sentence, int beamsize) {
  ComputationGraph cg;
  vector<int> best_int = generate_nbest(source_sentence, enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, dec_1_lstm, cg, beamsize, beamsize);
  string out;
  for (int token : best_int) {
    out += int_d.convert(token);
    out += " ";
  }
  cout << out << endl;
  return;
}


// Beam-search decoding for a single input sequence.
// in_seq: source feature frames; beamsize: beam width; nbest_1_size: number
// of finished hypotheses kept in the n-best list. Returns the token sequence
// of the best finished hypothesis (it begins with kSOS, pushed below).
vector<int> simpleunitask::generate_nbest(const vector<vector<float>>& in_seq, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, LSTMBuilder& enc_compress_lstm_1, LSTMBuilder& enc_compress_lstm_2, LSTMBuilder& dec_1_lstm, ComputationGraph& cg, int nbest_1_size, int beamsize) {
  
  // Attach every recurrent component to this graph and reset its state.
  enc_fwd_lstm.new_graph(cg);
  enc_bwd_lstm.new_graph(cg);
  enc_compress_lstm_1.new_graph(cg);
  enc_compress_lstm_2.new_graph(cg);
  dec_1_lstm.new_graph(cg);
  enc_fwd_lstm.start_new_sequence();
  enc_bwd_lstm.start_new_sequence();
  enc_compress_lstm_1.start_new_sequence();
  enc_compress_lstm_2.start_new_sequence();
  dec_1_lstm.start_new_sequence();

  // Embed the source features (stacked-frame variant when STACK_FEATS).
  vector<Expression> embedded;
  if (STACK_FEATS)
    embedded = embed_stack_features(in_seq, cg);
  else
    embedded = embed_features(in_seq, cg);
  vector<Expression> encoded = encode_features(enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm_1, enc_compress_lstm_2, embedded);
  
  // The n-best hypotheses for the first decoder
  // NOTE(review): nbest_2 is declared but never used in this function.
  vector<DecoderHypPtr> nbest_1, nbest_2;

  // These will hold the best sentences from the decoder
  vector<int> best_int;

  // First decoder's beam search
  Expression w = parameter(cg, decoder_1_w);
  Expression b = parameter(cg, decoder_1_b);
  Expression w1 = parameter(cg, attention_1_w1);
  Expression input_mat = concatenate_cols(encoded);
  // Precompute W1 * H once; attend_1 reuses it at every step.
  Expression w1dt = w1 * input_mat;

  cg.incremental_forward(w1dt);

  // Final decoder states per beam slot, saved so a hypothesis can be resumed.
  vector<vector<Expression>> last_states(beamsize);

  // Prime the decoder with the last encoder state and the SOS embedding.
  Expression last_output_embeddings = lookup(cg, output_1_lookup, kSOS);
  Expression init_vector = concatenate( { encoded.back(), last_output_embeddings});
  dec_1_lstm.add_input(init_vector);
  vector<Expression> last_s = dec_1_lstm.final_s();

  // Single seed hypothesis: score 0, sentence containing only kSOS.
  vector<int> init_sent;
  init_sent.push_back(kSOS);
  vector<DecoderHypPtr> curr_beam(1, DecoderHypPtr(new DecoderHyp(0.0, last_s, init_sent)));

  // Output length limit: twice the input length.
  // NOTE(review): the cap below resets the limit to the *input length*
  // (not to 80) when 2*len exceeds 80 — confirm this is intended.
  int size_limit_ = in_seq.size() * 2;
  if (size_limit_ > 80) size_limit_ = in_seq.size();

  // Beam search 1st decoder
  for (int sent_len = 0; sent_len <= size_limit_; sent_len++) {
    // Best (score, hypothesis id, word id) candidates, kept sorted
    // best-first; slot `beamsize` is scratch space for insertion.
    vector<tuple<dynet::real,int,int> > next_beam_id(beamsize+1, tuple<dynet::real,int,int>(-400000,-1,-1));

    // Iterate over the current beams and go one step forward
    for(int hypid = 0; hypid < (int)curr_beam.size(); hypid++) {
      // Get the current hypothesis
      DecoderHypPtr curr_hyp = curr_beam[hypid];
      // Get the current hypothesis sentence
      const vector<int>& sent = curr_hyp->GetSentence();
      // Do not expand a finished beam
      if (sent[sent_len] == kEOS) continue;

      // Embedding of this hypothesis' most recent output symbol.
      Expression last_output_embeddings = lookup(cg, output_1_lookup, sent[sent_len]);
  
      // Perform the forward step on the decoder (after init with its last state)
      if (sent_len == 0){
        dec_1_lstm.start_new_sequence();
        dec_1_lstm.add_input(init_vector);
      }
      else
        dec_1_lstm.start_new_sequence(curr_hyp->GetStates());

      //concatenate input weighted by attention and decoder lstm state
      Expression att_weights_1 = attend_1(dec_1_lstm, w1dt, cg);
      Expression context_1 = input_mat * att_weights_1;
      Expression in_vector = concatenate( { context_1, last_output_embeddings }); 
      dec_1_lstm.add_input(in_vector);
          
      // And now get the last softmax (log-probabilities over the vocab)
      Expression out_vector = log(softmax(w * dec_1_lstm.back() + b));
	    // Add length normalization: ((5+len)^a / 6^a) divisor on log-probs
	    float length_norm = pow(5 + sent_len, LENGTH_NORM_WEIGHT)/(pow(6,LENGTH_NORM_WEIGHT));
      vector<float> probs = as_vector(cg.incremental_forward(out_vector / length_norm));
	  

	  // Add unknown word penalty
      if(INT_UNK_ID >= 0) probs[INT_UNK_ID] += int_unk_log_prob_;
      // Keep the final state for the continuation of the beam
      last_states[hypid] = dec_1_lstm.final_s();
      // Find the best IDs
      for(int wid = 0; wid < (int)probs.size(); wid++) {
        // The new score will be the current score + the softmax score
        dynet::real my_score = curr_hyp->GetScore() + probs[wid];
        // Now go through the beams from bottom to the beginning
        // and only keep the best <beamsize>
        int bid;
        for(bid = beamsize; bid > 0 && my_score > get<0>(next_beam_id[bid-1]); bid--)
          next_beam_id[bid] = next_beam_id[bid-1];
        next_beam_id[bid] = tuple<dynet::real,int,int>(my_score,hypid,wid);
      }
    }
    // Create the new hypotheses
    vector<DecoderHypPtr> next_beam;
    for(int i = 0; i < beamsize; i++) {
      dynet::real score = get<0>(next_beam_id[i]);
      int hypid = get<1>(next_beam_id[i]);
      int wid = get<2>(next_beam_id[i]);
      // cerr << "Adding " << wid << " @ beam " << i << ": score=" << get<0>(next_beam_id[i]) - curr_beam[hypid]->GetScore() << endl;
      // hypid == -1 marks an unfilled slot; all later slots are empty too.
      if(hypid == -1) break;
      // Add the last word to the sentence
      vector<int> next_sent = curr_beam[hypid]->GetSentence();
      next_sent.push_back(wid);
      DecoderHypPtr hyp(new DecoderHyp(score, last_states[hypid], next_sent));     
      // If we are done, add it to the nbest list
      if(wid == kEOS || sent_len == size_limit_) 
        nbest_1.push_back(hyp);
      // Add it do the next beams to be expanded
      next_beam.push_back(hyp);
    }

    // Substitute beams with the next ones 
    curr_beam = next_beam;
    // Check if we're done with search
    if(nbest_1.size() != 0) {
      // NOTE(review): this sort relies on DecoderHypPtr's ordering —
      // presumably it ranks hypotheses best-first; confirm against the
      // DecoderHyp comparison operator.
      sort(nbest_1.begin(), nbest_1.end());
      // trim to top n options
      if(nbest_1.size() > nbest_1_size)
        nbest_1.resize(nbest_1_size);
      // Stop early once the n-best list is full and no live beam can
      // outscore its worst entry (or there are no beams left to expand).
      if(nbest_1.size() == nbest_1_size && (next_beam.size() == 0 || (*nbest_1.rbegin())->GetScore() >= next_beam[0]->GetScore()))
        break;
        // return nbest_1
    }
  }

  best_int = nbest_1[0]->GetSentence();
  return best_int;


}












